# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from copy import deepcopy

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.backends.vulkan.op_registry import handles_own_prepacking
from executorch.backends.vulkan.utils import is_param_node

from executorch.exir.dialects._ops import ops as exir_ops

from torch.export import ExportedProgram


def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
    """
    Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
    is responsible for transferring the tensor data, which is serialized with the model,
    to a GPU tensor object during the prepacking stage of model execution.

    Some operators are performance sensitive and will prefer to handle prepacking within
    the operator. For these ops, the constant tensor data will be passed directly as an
    argument into the operator implementation.

    The graph of `program` is modified in place: every parameter node that needs
    prepacking gets a `prepack` node inserted right after it, and the node's other users
    are re-pointed at that `prepack` node.

    Args:
        program: The exported program to transform. `SpecPropPass()` must already have
            been applied so that parameter nodes carry a "spec" entry in their metadata.

    Returns:
        The same `program` object that was passed in, after mutation.

    Raises:
        AssertionError: If a node that requires prepacking has no "spec" metadata, or
            if its spec is not marked as a constant.
    """

    def prepack_not_required(node: torch.fx.Node) -> bool:
        """
        Return True if no prepack node should be inserted for `node`.

        As a side effect, every parameter node is annotated with
        `meta["vkdg_tensorref"] = True`, whether or not it ends up being prepacked.
        """
        if not is_param_node(program, node):
            return True

        # Annotate that this node is going to be represented as a tensorref in the
        # Vulkan compute graph. This will be useful for later graph passes.
        node.meta["vkdg_tensorref"] = True

        # NOTE(review): a single user that handles its own prepacking suppresses the
        # prepack node for ALL users of this constant, including users that do not
        # handle prepacking themselves. Confirm that constants are never shared between
        # both kinds of users, or that this is the intended behavior.
        for user in node.users:
            if user.op == "call_function" and handles_own_prepacking(
                # pyre-ignore
                user.target
            ):
                return True

        return False

    graph = program.graph_module.graph

    # Nodes inserted during iteration are prepack nodes; they are not parameter nodes,
    # so `prepack_not_required` skips them if the iteration visits them.
    for node in graph.nodes:
        if prepack_not_required(node):
            continue

        # Validate preconditions BEFORE mutating the graph, so that a failed check does
        # not leave a dangling prepack node behind.
        # This pass assumes that the SpecPropPass() has already been applied
        assert "spec" in node.meta
        # Validate that the original node is marked as a constant. Constant tensors
        # do not participate in memory planning.
        assert node.meta["spec"].const

        with graph.inserting_after(node):
            prepack_node = graph.create_node(
                "call_function",
                exir_ops.edge.et_vk.prepack.default,
                (node,),
            )

        prepack_node.meta["val"] = node.meta["val"]
        # Deep copy so that editing the prepack node's spec below does not modify the
        # spec object held by the original node.
        prepack_node.meta["spec"] = deepcopy(node.meta["spec"])
        # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
        # memory object.
        prepack_node.meta["spec"].mem_obj_id = -1

        # Redirect every user of the constant to the prepack node, except the prepack
        # node itself, which must keep consuming the original constant. The default
        # argument binds the current prepack node at lambda creation time.
        node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)

    graph.eliminate_dead_code()
    return program