xref: /aosp_15_r20/external/executorch/exir/passes/debug_handle_generator_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
from collections import deque

from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.pass_base import ExportPass
from torch.export import ExportedProgram
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
12
13
class DebugHandleGeneratorPass(ExportPass):
    """Assign a unique integer ``debug_handle`` to every node in a graph.

    Handles are stored in ``node.meta["debug_handle"]``. They are numbered
    consecutively starting at 1. The root graph module is visited first, and
    its control-flow submodules follow in breadth-first order.
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        """Attach a debug handle to every node of ``graph_module`` and of all
        control-flow submodules reachable from it.

        Existing ``debug_handle`` entries are overwritten unconditionally.

        Args:
            graph_module: Root graph module. Node metadata is modified in place.

        Returns:
            ``PassResult(graph_module, True)`` for the same (mutated) module.
        """
        pending = deque([graph_module])
        next_handle = 1
        # Breadth-first walk so that nodes inside control-flow submodules also
        # receive handles. A module's nodes are numbered before any submodule
        # it enqueues, so handles increase with traversal order.
        while pending:
            current = pending.popleft()
            for node in current.graph.nodes:
                node.meta["debug_handle"] = next_handle
                next_handle += 1
            pending.extend(
                submodule
                for _, submodule, _ in get_control_flow_submodules(current)
            )
        return PassResult(graph_module, True)
34
35
def generate_missing_debug_handles(ep: ExportedProgram):
    """Fill in debug handles for nodes of ``ep`` that do not have one.

    Walks ``ep.graph_module`` and its control-flow submodules breadth-first.
    A node counts as missing a handle when ``node.meta`` has no
    ``"debug_handle"`` key, or when the value is ``0`` or ``None``.

    Each such node gets the next integer above the largest handle already
    present anywhere in the program. That baseline is never lower than 0, so
    numbering starts at 1 when no handles exist. Existing non-missing handles
    are left untouched, and new handles never collide with them.

    Node metadata is modified in place; the function returns ``None``.

    Args:
        ep: Exported program whose graph module (and control-flow submodules)
            should be annotated.
    """
    # Gather every node once, in BFS module order, so that both steps below
    # share one traversal instead of re-walking the submodule tree.
    nodes = []
    pending = deque([ep.graph_module])
    while pending:
        current = pending.popleft()
        nodes.extend(current.graph.nodes)
        pending.extend(
            submodule for _, submodule, _ in get_control_flow_submodules(current)
        )

    max_handle = 0
    for node in nodes:
        handle = node.meta.get("debug_handle")
        # A None handle is treated as "missing" below. Comparing it with an
        # int in max() would raise TypeError, so skip it here.
        if handle is not None:
            max_handle = max(max_handle, handle)

    for node in nodes:
        if node.meta.get("debug_handle", 0) in (0, None):
            max_handle += 1
            node.meta["debug_handle"] = max_handle
66