# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from types import NoneType
from typing import cast, List, Optional, Union

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)
from executorch.backends.vulkan.utils import (
    is_constant,
    is_get_attr_node,
    is_param_node,
)
from executorch.exir.backend.utils import DelegateMappingBuilder

from executorch.exir.tensor import TensorSpec
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
from torch.export import ExportedProgram
from torch.fx import Node

_ScalarType = Union[bool, int, float]
_Argument = Union[
    Node, NoneType, _ScalarType, TensorSpec, List[_ScalarType], List[Node], str
]

logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)


class VkGraphBuilder:
    """
    Builds a serializable VkGraph from an ExportedProgram by walking its FX
    graph and converting each node into VkValues and OperatorCalls.
    """

    def __init__(
        self,
        program: ExportedProgram,
        delegate_mapping_builder: DelegateMappingBuilder,
    ) -> None:
        self.program = program
        self.delegate_mapping_builder = delegate_mapping_builder
        self.chain = []  # Serialized OperatorCalls, in execution order
        self.values = []  # All VkValues referenced by the graph
        self.input_ids = []  # Value ids of graph inputs
        self.output_ids = []  # Value ids of graph outputs
        self.const_tensors = []  # Constant tensors, indexed by constant_id

        # Mapping from Node to VkValue id
        self.node_to_value_ids = {}

        # For logging
        self.seen_ops = set()

    @staticmethod
    def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
        if torch_dtype == torch.bool:
            return vk_graph_schema.VkDataType.BOOL
        elif torch_dtype == torch.uint8:
            return vk_graph_schema.VkDataType.UINT8
        elif torch_dtype == torch.int8:
            return vk_graph_schema.VkDataType.INT8
        elif torch_dtype == torch.int32:
            return vk_graph_schema.VkDataType.INT32
        elif torch_dtype == torch.float16:
            return vk_graph_schema.VkDataType.FLOAT16
        elif torch_dtype == torch.float32:
            return vk_graph_schema.VkDataType.FLOAT32
        # Narrowing conversion for the index tensor produced by max_poolNd_with_indices.
        elif torch_dtype == torch.int64:
            return vk_graph_schema.VkDataType.INT32
        else:
            raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

    def get_constant(self, node: Node) -> Optional[torch.Tensor]:
        """
        Returns the constant associated with the given node in the exported
        program, or None if the node is not a constant within the exported
        program.
        """
        if is_constant(self.program, node):
            constant_name = (
                self.program.graph_signature.inputs_to_lifted_tensor_constants[
                    node.name
                ]
            )
            if constant_name in self.program.constants:
                return self.program.constants[constant_name]
            else:
                return None

        return None

    def get_param_tensor(self, node: Node) -> torch.Tensor:
        tensor = None
        if node is None:
            raise RuntimeError("node is None")
        elif is_param(self.program, node):
            tensor = get_param(self.program, node)
        elif is_buffer(self.program, node):
            tensor = get_buffer(self.program, node)
        elif is_constant(self.program, node):
            tensor = self.get_constant(node)
        elif is_get_attr_node(node):
            # This is a hack to support both lifted and unlifted graphs
            try:
                tensor = getattr(node.graph.owning_module, node.target)
            except AttributeError:
                tensor = getattr(self.program.graph_module, node.target)
        else:
            raise RuntimeError(f"unsupported param type: {node.op}")

        assert tensor is not None
        return tensor

    def maybe_add_constant_tensor(self, node: Node) -> int:
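        # If the node is backed by a parameter, buffer, or lifted constant,
        # record its tensor in const_tensors and return its index; otherwise
        # return -1 to indicate that the node has no constant data.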
        constant_id = -1
        if is_param_node(self.program, node):
            constant_id = len(self.const_tensors)
            self.const_tensors.append(self.get_param_tensor(node))

        return constant_id

    def create_node_value(self, node: Node) -> int:
        # If the node has been marked as a scalar tensor, create a SymInt instead of a tensor
        if node.meta.get("vkdg_is_scalar_tensor", False):
            new_id = self.create_symint_value()
            self.node_to_value_ids[node] = new_id
            return new_id

        spec = node.meta.get("spec")
        if isinstance(spec, TensorSpec):
            constant_id = self.maybe_add_constant_tensor(node)
            new_id = self.create_tensor_value(spec, constant_id)
            self.node_to_value_ids[node] = new_id
            return new_id
        elif isinstance(spec, (list, tuple)):
            # pyre-ignore[6]: pyre has a hard time inferring the Node type
            # inside the container.
            new_id = self.create_value_list_value(spec)
            self.node_to_value_ids[node] = new_id
            return new_id
        else:
            raise RuntimeError(f"Cannot create value for spec of type {type(spec)}")

    def create_null_value(self) -> int:
        new_id = len(self.values)
        self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Null()))
        return new_id

    def create_scalar_value(self, scalar: _ScalarType) -> int:
        new_id = len(self.values)
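        # Note: bool must be checked before int, since bool is a subclass of
        # int in Python and would otherwise be serialized as an Int value.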
        if isinstance(scalar, bool):
            self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar)))
        elif isinstance(scalar, int):
            self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
        elif isinstance(scalar, float):
            self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
        return new_id

    def create_symint_value(self) -> int:
        new_id = len(self.values)
        self.values.append(vk_graph_schema.VkValue(vk_graph_schema.SymInt(0)))
        return new_id

    def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
        # A negative mem_obj_id indicates that this tensor will have its own dedicated memory.
        mem_obj_id = -1
        if spec.mem_obj_id is not None:
            mem_obj_id = spec.mem_obj_id

        storage_type = VkStorageType.DEFAULT_STORAGE
        memory_layout = VkMemoryLayout.DEFAULT_LAYOUT
        if hasattr(spec, "vk_storage_type"):
            # pyre-ignore[16]
            storage_type = spec.vk_storage_type
        if hasattr(spec, "vk_memory_layout"):
            # pyre-ignore[16]
            memory_layout = spec.vk_memory_layout

        new_id = len(self.values)
        self.values.append(
            vk_graph_schema.VkValue(
                value=vk_graph_schema.VkTensor(
                    datatype=self.get_vk_datatype(spec.dtype),
                    dims=spec.shape,
                    constant_id=constant_id,
                    mem_obj_id=mem_obj_id,
                    storage_type=storage_type,
                    memory_layout=memory_layout,
                )
            )
        )
        return new_id

    def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
        new_id = len(self.values)
        if len(arg) == 0:
            self.values.append(
                vk_graph_schema.VkValue(vk_graph_schema.IntList(items=[]))
            )
        elif isinstance(arg[0], bool):
            self.values.append(
                vk_graph_schema.VkValue(
                    vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
                )
            )
        elif isinstance(arg[0], int):
            self.values.append(
                vk_graph_schema.VkValue(
                    vk_graph_schema.IntList(items=[cast(int, e) for e in arg])
                )
            )
        elif isinstance(arg[0], float):
            self.values.append(
                vk_graph_schema.VkValue(
                    vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg])
                )
            )
        return new_id

    def create_value_list_value(self, arg: tuple | list) -> int:
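        # Each element is converted to a VkValue first; since those conversions
        # may themselves append to self.values, the id of the ValueList is
        # taken after the append rather than reserved up front.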
        self.values.append(
            vk_graph_schema.VkValue(
                vk_graph_schema.ValueList(
                    items=[self.get_or_create_value_for(e) for e in arg]
                )
            )
        )
        return len(self.values) - 1

    def create_string_value(self, string: str) -> int:
        new_id = len(self.values)
        self.values.append(
            vk_graph_schema.VkValue(vk_graph_schema.String(string_val=string))
        )
        return new_id

    def get_or_create_value_for(self, arg: _Argument) -> int:
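        """
        Returns the id of the VkValue corresponding to `arg`, creating a new
        VkValue for it if one does not already exist.
        """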
        if isinstance(arg, Node):
            # If the Node has already been processed, return the existing id.
            if arg in self.node_to_value_ids:
                return self.node_to_value_ids[arg]
            return self.create_node_value(arg)
        elif (
            isinstance(arg, NoneType)
            or isinstance(arg, torch.device)
            or isinstance(arg, torch.dtype)
            or isinstance(arg, torch.layout)
            or isinstance(arg, torch.memory_format)
        ):
            return self.create_null_value()
        elif isinstance(arg, _ScalarType):
            return self.create_scalar_value(arg)
        elif isinstance(arg, TensorSpec):
            return self.create_tensor_value(arg)
        elif isinstance(arg, list) and (
            len(arg) == 0 or isinstance(arg[0], _ScalarType)
        ):
            # pyre-ignore[6]
            return self.create_scalar_list_value(arg)
        elif isinstance(arg, list) and isinstance(arg[0], Node):
            return self.create_value_list_value(arg)
        elif isinstance(arg, torch.fx.immutable_collections.immutable_list):
            # pyre-ignore[6]
            return self.create_value_list_value(arg)
        elif isinstance(arg, str):
            return self.create_string_value(arg)
        else:
            raise RuntimeError(f"Cannot create value for arg of type {type(arg)}")

    def process_placeholder_node(self, node: Node) -> None:
        # Ignore any tensors that are not used in any ops.
        if len(node.users) == 0:
            return None
        ids = self.create_node_value(node)
        if not is_param_node(self.program, node):
            if isinstance(ids, int):
                self.input_ids.append(ids)
            else:
                self.input_ids += ids

    def process_getitem_node(self, node: Node) -> None:
        # Find ValueList id from the collection node.
        collection_node = node.all_input_nodes[0]
        list_id = self.node_to_value_ids[collection_node]

        # Extract the target Value id from ValueList.
        valuelist_id = node.args[1]
        value_id = self.values[list_id].value.items[valuelist_id]

        # Map Node to Value id.
        self.node_to_value_ids[node] = value_id

    def process_call_function_node(self, node) -> None:
        operator_call_args = []

        self.seen_ops.add(node.target)

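        # Resolve each argument declared by the operator's schema: use the
        # positional arg if present, otherwise the matching kwarg, otherwise
        # the schema's default value.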
        for i, schema_arg in enumerate(node.target._schema.arguments):
            if not schema_arg.kwarg_only and i < len(node.args):
                function_arg = node.args[i]
            elif schema_arg.name in node.kwargs:
                function_arg = node.kwargs[schema_arg.name]
            else:
                function_arg = schema_arg.default_value

            # Create a Value for each function argument. If the argument has been
            # previously encountered, then use the existing Value id.
            operator_call_args.append(self.get_or_create_value_for(function_arg))

        # Add the output value of this node as the last argument.
        operator_call_args.append(self.create_node_value(node))
        operator_node_id = (
            0
            if not self.delegate_mapping_builder
            else self.delegate_mapping_builder.insert_delegate_mapping_entry(node)
        )
        self.chain.append(
            vk_graph_schema.OperatorCall(
                node_id=operator_node_id,  # pyre-ignore[6]: this is going to be an int
                name=node.target.__name__,
                args=operator_call_args,
            ),
        )

    def process_getattr_node(self, node: Node) -> None:
        self.create_node_value(node)

    def process_output_node(self, node: Node) -> None:
        for out_node in node.all_input_nodes:
            if out_node not in self.node_to_value_ids:
                raise AssertionError(
                    "Cannot find input to output node in node_to_value_ids. This means "
                    "the output node is being serialized before its corresponding "
                    "internal node, which is not allowed."
                )
            self.output_ids.append(self.node_to_value_ids[out_node])

    def process_node(self, node: Node, call_node_debug_hdl: int) -> None:
        if node.op == "placeholder":
            self.process_placeholder_node(node)
        elif node.op == "call_function":
            if node.target == operator.getitem:
                self.process_getitem_node(node)
            else:
                node.meta["debug_handle"] = call_node_debug_hdl
                self.process_call_function_node(node)
        elif node.op == "get_attr":
            self.process_getattr_node(node)
        elif node.op == "output":
            self.process_output_node(node)
        else:
            raise AssertionError(f"Unsupported node op: {node.op}")

    def build_graph(self) -> vk_graph_schema.VkGraph:
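        """
        Processes each node of the ExportedProgram's FX graph in order, logs
        the set of operators encountered, and returns the assembled VkGraph.
        """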
        call_node_debug_hdl = 0
        for node in self.program.graph_module.graph.nodes:
            self.process_node(node, call_node_debug_hdl)
            call_node_debug_hdl += 1

        logger.info("Operators included in this Vulkan partition: ")
        for op in self.seen_ops:
            logger.info(f"    {op.__name__}")

        return vk_graph_schema.VkGraph(
            version="0",
            chain=self.chain,
            values=self.values,
            input_ids=self.input_ids,
            output_ids=self.output_ids,
            constants=[],
            shaders=[],
        )
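

# Example usage (a minimal sketch; `exported_program` stands in for an
# ExportedProgram that has already been partitioned and preprocessed for the
# Vulkan delegate, and is not defined in this module):
#
#     builder = VkGraphBuilder(exported_program, DelegateMappingBuilder())
#     vk_graph = builder.build_graph()
#     # builder.const_tensors now holds the constant tensors referenced by the
#     # constant_id fields of the serialized VkTensors.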
380