# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from collections import defaultdict
from dataclasses import asdict, dataclass, fields
from typing import Dict

import pandas as pd
import torch


# Column names of the DataFrame returned by DelegationInfo.get_operator_delegation_dataframe()
# which describes the summarized delegation information grouped by each operator type
_OCCURRENCES_IN_DELEGATED_GRAPHS = "occurrences_in_delegated_graphs"
_OCCURRENCES_IN_NON_DELEGATED_GRAPHS = "occurrences_in_non_delegated_graphs"

# Matches a trailing "_<integer>" suffix of a node name, e.g. the "_1" in
# "aten_add_tensor_1"; stripping it yields the operator type.
_NODE_NAME_SUFFIX_RE = re.compile(r"_\d+$")


@dataclass
class DelegationBreakdown:
    """
    DelegationBreakdown contains the number of delegated and non-delegated nodes
    of the operator type op_type.

    Args:
        op_type: The operator type these counts belong to.
        delegated: The number of delegated nodes.
        non_delegated: The number of non-delegated nodes.
    """

    op_type: str = ""
    delegated: int = 0
    non_delegated: int = 0


@dataclass
class DelegationInfo:
    """
    DelegationInfo contains information of a delegated graph module.

    Args:
        num_delegated_subgraphs: The number of delegated subgraphs.
        num_delegated_nodes: The number of delegated nodes.
        num_non_delegated_nodes: The number of non-delegated nodes.
        delegation_by_operator: A dictionary of operator type to DelegationBreakdown.
    """

    num_delegated_subgraphs: int
    num_delegated_nodes: int
    num_non_delegated_nodes: int
    delegation_by_operator: Dict[str, DelegationBreakdown]

    def get_summary(self) -> str:
        """
        Get a summary of the delegation information in string format.

        Args:
            None

        Returns:
            A string containing information of some class attributes for easy print-out.
        """

        # Assemble and return the summary string
        summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n"
        summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n"
        summary_str += (
            f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n"
        )
        return summary_str

    def get_operator_delegation_dataframe(self) -> pd.DataFrame:
        """
        Get the delegation information grouped by operator type in a pandas DataFrame.

        Args:
            None

        Returns:
            Returns a pandas DataFrame containing the following columns:
            - op_type: The operator type, with the last row being "Total".
            - occurrences_in_delegated_graphs: The number of occurrences of the op_type in delegated subgraphs.
            - occurrences_in_non_delegated_graphs: The number of occurrences of the op_type not in delegated subgraphs.
            Rows are sorted by op_type. The last row holds the total number of
            delegated and non-delegated occurrences across all op_types. If
            delegation_by_operator is empty, only the "Total" row (with zero
            counts) is returned.
        """

        # Pass the column names explicitly: with an empty delegation_by_operator,
        # pd.DataFrame([]) would have no "op_type" column and sort_values below
        # would raise KeyError.
        columns = [field.name for field in fields(DelegationBreakdown)]
        df = pd.DataFrame(
            [asdict(breakdown) for breakdown in self.delegation_by_operator.values()],
            columns=columns,
        )
        # Rename columns for better understandability
        df = df.rename(
            columns={
                "delegated": _OCCURRENCES_IN_DELEGATED_GRAPHS,
                "non_delegated": _OCCURRENCES_IN_NON_DELEGATED_GRAPHS,
            }
        )
        df = df.sort_values(by="op_type", ignore_index=True)

        # Add a Total row at the bottom; the list order matches the column order
        # (op_type, delegated, non_delegated) taken from DelegationBreakdown.
        total_delegated_nodes = df[_OCCURRENCES_IN_DELEGATED_GRAPHS].sum()
        total_non_delegated_nodes = df[_OCCURRENCES_IN_NON_DELEGATED_GRAPHS].sum()
        df.loc[len(df)] = ["Total", total_delegated_nodes, total_non_delegated_nodes]

        return df


def get_delegation_info(
    graph_module: torch.fx.GraphModule,
) -> DelegationInfo:
    """
    Util function to get the delegation information of the given graph module.

    Nodes of the top-level graph with op "call_function" are counted as
    non-delegated, except calls to "executorch_call_delegate". Each "get_attr"
    node whose name starts with "lowered_module_" counts as one delegated
    subgraph, and every "call_function" node in that attribute's
    original_module.graph is counted as delegated.

    Args:
        graph_module: The lowered graph module to get the delegation information from.

    Returns:
        Return a DelegationInfo object containing the delegation information.
    """

    def _get_op_type(node_name: str) -> str:
        # node_name is in format <op_type> or <op_type>_x in which x is an integer suffix.
        return _NODE_NAME_SUFFIX_RE.sub("", node_name)

    # Missing op types are created on first access with zero counts.
    op_occurrences_dict: Dict[str, DelegationBreakdown] = defaultdict(
        DelegationBreakdown
    )

    def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None:
        op_type = _get_op_type(node_name)
        breakdown = op_occurrences_dict[op_type]
        breakdown.op_type = op_type
        if delegated:
            breakdown.delegated += 1
        else:
            breakdown.non_delegated += 1

    delegated_subgraph_counter = 0

    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            # Calls to executorch_call_delegate are skipped here; the ops inside
            # each delegated subgraph are counted from the lowered module below.
            if _get_op_type(node.name) != "executorch_call_delegate":
                # Non-delegated node
                _insert_op_occurrences_dict(node_name=node.name, delegated=False)
        elif node.op == "get_attr" and node.name.startswith("lowered_module_"):
            # The node refers to a lowered module: count its delegated ops.
            lowered_module = getattr(graph_module, node.name)
            delegated_subgraph_counter += 1
            for node_in_lowered_module in lowered_module.original_module.graph.nodes:
                if node_in_lowered_module.op == "call_function":
                    # Delegated node
                    _insert_op_occurrences_dict(
                        node_name=node_in_lowered_module.name, delegated=True
                    )

    # Calculate the total number of delegated and non-delegated nodes
    num_delegated_nodes = sum(
        breakdown.delegated for breakdown in op_occurrences_dict.values()
    )
    num_non_delegated_nodes = sum(
        breakdown.non_delegated for breakdown in op_occurrences_dict.values()
    )

    return DelegationInfo(
        num_delegated_nodes=num_delegated_nodes,
        num_non_delegated_nodes=num_non_delegated_nodes,
        num_delegated_subgraphs=delegated_subgraph_counter,
        delegation_by_operator=op_occurrences_dict,
    )