xref: /aosp_15_r20/external/executorch/backends/vulkan/_passes/insert_prepack_nodes.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9from copy import deepcopy
10
11import executorch.backends.vulkan.custom_ops_lib  # noqa
12
13import torch
14
15from executorch.backends.vulkan.op_registry import handles_own_prepacking
16from executorch.backends.vulkan.utils import is_param_node
17
18from executorch.exir.dialects._ops import ops as exir_ops
19
20from torch.export import ExportedProgram
21
22
def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
    """
    Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
    is responsible for transferring the tensor data, which is serialized with the model,
    to a GPU tensor object during the prepacking stage of model execution.

    Some operators are performance sensitive and will prefer to handle prepacking within
    the operator. For these ops, the constant tensor data will be passed directly as an
    argument into the operator implementation. When a constant tensor is shared between
    such an operator and other users, only the other users are rewired to consume the
    inserted prepack node; the self-prepacking operator keeps the original constant.

    Args:
        program: The exported program whose graph is modified in place. This pass
            assumes that SpecPropPass() has already been applied, so that every
            constant tensor node receiving a prepack node has a "spec" metadata entry.

    Returns:
        The same `program` object that was passed in.

    Raises:
        AssertionError: If a constant tensor node that requires a prepack node has no
            "spec" metadata, or if its spec is not marked as constant.
    """

    def needs_prepack(user: torch.fx.Node) -> bool:
        # Every user is given a prepacked tensor unless it is a function call whose
        # target operator declares that it handles its own prepacking.
        return not (
            user.op == "call_function"
            and handles_own_prepacking(
                # pyre-ignore
                user.target
            )
        )

    graph = program.graph_module.graph

    # Iterate over a snapshot of the node list, since prepack nodes are inserted into
    # the graph while it is being traversed.
    for node in list(graph.nodes):
        if not is_param_node(program, node):
            continue

        # Annotate that this node is going to be represented as a tensorref in the
        # Vulkan compute graph. This will be useful for later graph passes.
        node.meta["vkdg_tensorref"] = True

        # Collect the users before the prepack node exists, so that the prepack node
        # itself is never selected for rewiring.
        prepack_consumers = {user for user in node.users if needs_prepack(user)}
        if not prepack_consumers:
            continue

        # Validate before touching the graph so that a failure does not leave a
        # dangling prepack node behind. Raised explicitly (rather than via `assert`)
        # so that the checks are not stripped when running with `python -O`.
        if "spec" not in node.meta:
            raise AssertionError(
                f"Node {node.name} has no 'spec' metadata; "
                "SpecPropPass() must be applied before insert_prepack_nodes()"
            )
        # Validate that the original node is marked as a constant. Constant tensors
        # do not participate in memory planning.
        if not node.meta["spec"].const:
            raise AssertionError(
                f"Node {node.name} is a parameter node but its spec is not constant"
            )

        with graph.inserting_after(node):
            prepack_node = graph.create_node(
                "call_function",
                exir_ops.edge.et_vk.prepack.default,
                (node,),
            )

        prepack_node.meta["val"] = node.meta["val"]
        prepack_node.meta["spec"] = deepcopy(node.meta["spec"])
        # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
        # memory object.
        prepack_node.meta["spec"].mem_obj_id = -1

        # Rewire only the users that need a prepacked tensor. The prepack node and any
        # self-prepacking operators continue to read the original constant.
        node.replace_all_uses_with(
            prepack_node,
            lambda user, consumers=prepack_consumers: user in consumers,
        )

    graph.eliminate_dead_code()
    return program
75