xref: /aosp_15_r20/external/executorch/devtools/backend_debug/delegation_info.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import re
8from collections import defaultdict
9from dataclasses import asdict, dataclass
10from typing import Dict
11
12import pandas as pd
13import torch
14
15
# Column headers produced by DelegationInfo.get_operator_delegation_dataframe().
# That DataFrame summarizes delegation per operator type: one column counts the
# occurrences of the op type inside delegated subgraphs, the other counts the
# occurrences outside of them.
_OCCURRENCES_IN_DELEGATED_GRAPHS: str = "occurrences_in_delegated_graphs"
_OCCURRENCES_IN_NON_DELEGATED_GRAPHS: str = "occurrences_in_non_delegated_graphs"
20
21
@dataclass
class DelegationBreakdown:
    """
    DelegationBreakdown contains the number of delegated and non-delegated nodes
    of the operator type op_type.

    Args:
        op_type: The operator type these counts belong to.
        delegated: The number of delegated nodes.
        non_delegated: The number of non-delegated nodes.
    """

    op_type: str = ""
    delegated: int = 0
    non_delegated: int = 0
36
37
@dataclass
class DelegationInfo:
    """
    DelegationInfo contains information of a delegated graph module.

    Args:
        num_delegated_subgraphs: The number of delegated subgraphs.
        num_delegated_nodes: The number of delegated nodes.
        num_non_delegated_nodes: The number of non-delegated nodes.
        delegation_by_operator: A dictionary of operator type to DelegationBreakdown.
    """

    num_delegated_subgraphs: int
    num_delegated_nodes: int
    num_non_delegated_nodes: int
    delegation_by_operator: Dict[str, DelegationBreakdown]

    def get_summary(self) -> str:
        """
        Get a summary of the delegation information in string format.

        Args:
            None

        Returns:
            A three-line string (each line newline-terminated) reporting the
            number of delegated subgraphs, delegated nodes and non-delegated
            nodes, suitable for printing.
        """
        return (
            f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n"
            f"Number of delegated nodes: {self.num_delegated_nodes}\n"
            f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n"
        )

    def get_operator_delegation_dataframe(self) -> pd.DataFrame:
        """
        Get the delegation information grouped by operator type in a pandas DataFrame.

        Args:
            None

        Returns:
            Returns a pandas DataFrame containing the following columns:
            - op_type: The operator type, with the last row being "Total".
            - occurrences_in_delegated_graphs: The number of occurrences of the op_type in delegated subgraphs.
            - occurrences_in_non_delegated_graphs: The number of occurrences of the op_type not in delegated subgraphs.
            Operator rows are sorted by op_type. The last row holds the totals of
            both count columns. When delegation_by_operator is empty, the frame
            contains only the "Total" row with zero counts.
        """
        # Sort up front so the resulting frame already has a 0..n index.
        breakdowns = sorted(
            self.delegation_by_operator.values(), key=lambda b: b.op_type
        )
        records = [asdict(breakdown) for breakdown in breakdowns]
        records.append(
            {
                "op_type": "Total",
                "delegated": sum(b.delegated for b in breakdowns),
                "non_delegated": sum(b.non_delegated for b in breakdowns),
            }
        )

        # Explicit columns keep the schema (and integer count dtypes) stable even
        # when there are no operator rows; building the frame from an empty list
        # would otherwise yield a column-less DataFrame.
        df = pd.DataFrame(records, columns=["op_type", "delegated", "non_delegated"])
        # Rename columns for better understandability
        return df.rename(
            columns={
                "delegated": _OCCURRENCES_IN_DELEGATED_GRAPHS,
                "non_delegated": _OCCURRENCES_IN_NON_DELEGATED_GRAPHS,
            }
        )
108        return df
109
110
def get_delegation_info(
    graph_module: torch.fx.GraphModule,
) -> DelegationInfo:
    """
    Util function to get the delegation information of the given graph module.

    Counting rules:
    - Each ``call_function`` node of ``graph_module`` is counted as
      non-delegated, except the ``executorch_call_delegate`` calls.
    - Each ``get_attr`` node whose name starts with ``lowered_module_`` counts
      as one delegated subgraph.
    - Each ``call_function`` node in that lowered module's
      ``original_module.graph`` is counted as delegated.
    - Nodes are grouped by op type, which is the node name with any trailing
      ``_<integer>`` suffix removed.

    Args:
        graph_module: The lowered graph module to get the delegation information from.

    Returns:
        Return a DelegationInfo object containing the delegation information.
        Its ``delegation_by_operator`` is a plain dict keyed by op type.
    """

    def _get_op_type(node_name: str) -> str:
        # node_name is in format <op_type> or <op_type>_x in which x is an integer suffix.
        return re.sub(r"_[\d]+$", "", node_name)

    # Plain dict rather than a defaultdict: this mapping is handed to callers
    # as DelegationInfo.delegation_by_operator, and a defaultdict would
    # silently insert an empty entry on any missing-key lookup.
    op_occurrences_dict: Dict[str, DelegationBreakdown] = {}

    def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None:
        op_type = _get_op_type(node_name)
        breakdown = op_occurrences_dict.get(op_type)
        if breakdown is None:
            breakdown = DelegationBreakdown(op_type=op_type)
            op_occurrences_dict[op_type] = breakdown
        if delegated:
            breakdown.delegated += 1
        else:
            breakdown.non_delegated += 1

    delegated_subgraph_counter = 0

    for node in graph_module.graph.nodes:
        if (
            node.op == "call_function"
            and _get_op_type(node.name) != "executorch_call_delegate"
        ):
            # Non-delegated node
            _insert_op_occurrences_dict(node_name=node.name, delegated=False)
        elif node.op == "get_attr" and node.name.startswith("lowered_module_"):
            # The node refers to a lowered module attribute on graph_module;
            # count it as one subgraph and tally its inner call_function nodes.
            lowered_module = getattr(graph_module, node.name)
            delegated_subgraph_counter += 1
            for node_in_lowered_module in lowered_module.original_module.graph.nodes:
                if node_in_lowered_module.op == "call_function":
                    # Delegated node
                    _insert_op_occurrences_dict(
                        node_name=node_in_lowered_module.name, delegated=True
                    )

    # Calculate the total number of delegated and non-delegated nodes
    num_delegated_nodes = sum(b.delegated for b in op_occurrences_dict.values())
    num_non_delegated_nodes = sum(
        b.non_delegated for b in op_occurrences_dict.values()
    )

    return DelegationInfo(
        num_delegated_nodes=num_delegated_nodes,
        num_non_delegated_nodes=num_non_delegated_nodes,
        num_delegated_subgraphs=delegated_subgraph_counter,
        delegation_by_operator=op_occurrences_dict,
    )
176    )
177