import copy
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import torch
from torch.ao.ns.fx.utils import compute_sqnr
from torch.fx import GraphModule, Node
from torch.nn import functional as F


NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"
CUSTOM_KEY = "custom"

log = logging.getLogger(__name__)


def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
    """Attach a unique ``numeric_debug_handle`` to every node that lacks one.

    The handle is stored at ``node.meta["custom"]["numeric_debug_handle"]``.
    ``placeholder`` and ``output`` nodes are skipped. Nodes that already carry
    a handle keep it. Newly assigned handles start after the largest integer
    handle already present in the graph, so they never collide with handles
    assigned earlier (e.g. on a partially annotated graph).

    The graph nodes of the input model are modified in place.

    Args:
        graph_module: the model whose graph nodes are annotated.
    """
    # First pass: find the largest handle already present so that fresh ids
    # cannot duplicate an existing one. With no existing handles we start at 0.
    existing_handles = [
        node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
        for node in graph_module.graph.nodes
        if NUMERIC_DEBUG_HANDLE_KEY in node.meta.get(CUSTOM_KEY, {})
    ]
    unique_id = (
        max((h for h in existing_handles if isinstance(h, int)), default=-1) + 1
    )

    # Second pass: assign fresh, consecutive ids to nodes without a handle.
    for node in graph_module.graph.nodes:
        if node.op in ["output", "placeholder"]:
            continue

        if CUSTOM_KEY not in node.meta:
            node.meta[CUSTOM_KEY] = {}

        if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]:
            node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
            unique_id += 1


class OutputLogger(torch.nn.Module):
    """
    Base class for capturing output values for nodes in a GraphModule. It only
    captures Tensor outputs currently, but it can be extended to other types
    of inputs later if needed.
    """

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(
        self,
        debug_handle: int,
        node_name: Optional[str] = None,
        nn_module_stack: Optional[object] = None,
    ) -> None:
        super().__init__()
        self.node_name = node_name
        self.nn_module_stack = nn_module_stack
        self.debug_handle = debug_handle
        # One detached tensor per forward call that received a Tensor.
        self.stats: List[torch.Tensor] = []

    def forward(self, x: object) -> object:
        """Record ``x`` (detached) if it is a Tensor and return it unchanged."""
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        return x

    def extra_repr(self) -> str:
        """Extra information shown by ``nn.Module.__repr__``.

        ``nn.Module.__repr__`` wraps this string in ``OutputLogger(...)``, so
        no closing parenthesis is added here.
        """
        return (
            f"debug_handle={self.debug_handle}, node_name={self.node_name}, "
            f"nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)}"
        )

    # Backward-compatible alias for the previous (never invoked) method name.
    __extra_repr__ = extra_repr


def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:
    """For a given node, add an OutputLogger that observes the output of that
    node, and make all of the node's users consume the OutputLogger output.

    The OutputLogger carries ``debug_handle``, which can be used to compare
    graphs after transforms.

    Args:
        model: module to register the logger on and whose graph is edited.
        node: node whose output is observed.
        debug_handle: handle stored on the created OutputLogger.

    Returns:
        The newly created ``call_module`` logger node.
    """

    # to avoid circular dep
    from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

    # add a logger after the node
    with model.graph.inserting_after(node):
        get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger")
        logger_name = get_new_attr_name(model)
        setattr(
            model,
            logger_name,
            OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")),
        )
        logger_node = model.graph.call_module(logger_name, (node,), {})

    # Snapshot users: this list already contains logger_node (it consumes
    # `node`), which must not be rewired to itself.
    orig_users = list(node.users.keys())
    for user_node in orig_users:
        if user_node is logger_node:
            continue
        user_node.replace_input_with(node, logger_node)

    return logger_node


def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
    """Add output loggers to every node that has a numeric_debug_handle.

    Args:
        model (GraphModule): original model; it is not modified.
    Returns:
        A deep copy of the model with output loggers for all nodes that have a
        numeric_debug_handle.
    """
    # don't change the original model
    model = copy.deepcopy(model)
    # Iterate over a snapshot: _insert_logger adds nodes to the graph, and we
    # don't want to walk the linked list while it is being modified.
    for n in list(model.graph.nodes):
        if (
            CUSTOM_KEY not in n.meta
            or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY]
        ):
            continue
        numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
        _insert_logger(model, n, numeric_debug_handle)

    model.recompile()
    return model


@dataclass(frozen=True)
class QuantizationComparisonResult:
    """A pair of tensors (actual vs. reference) with derived error metrics."""

    actual: torch.Tensor
    ref: torch.Tensor

    @property
    def mse_loss(self) -> torch.Tensor:
        """Mean squared error between ``actual`` and ``ref``, computed in float32."""
        return F.mse_loss(
            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
        )

    @property
    def sqnr(self) -> torch.Tensor:
        """``compute_sqnr(actual, ref)`` with both operands cast to float32."""
        return compute_sqnr(
            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
        )

    def __repr__(self) -> str:
        # Don't include the tensors themselves as they are quite large to print
        # out.
        return (
            f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})"
        )

    def __post_init__(self) -> None:
        # Validate eagerly so that bad inputs fail at construction time rather
        # than when a metric is first accessed.
        if not isinstance(self.actual, torch.Tensor):
            raise ValueError(
                f"`self.actual` value must be a Tensor, got: {self.actual}"
            )

        if not isinstance(self.ref, torch.Tensor):
            raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")


@dataclass(frozen=True)
class NodeAccuracySummary:
    """Comparison results for one debug handle across two models."""

    handle: int
    actual_node_name: str
    actual_module_stack: str
    ref_node_name: str
    ref_module_stack: str
    results: Sequence[QuantizationComparisonResult]


def _module_stack_to_str(module_stack: object) -> str:
    """Simplify an ``nn_module_stack`` dict to the name of its innermost module.

    For a dict, the first element of the last value is returned as a string.
    For example, with values ``("mod", ...)``, ``("mod.foo", ...)``, ...,
    ``("mod.foo.0.linear", ...)`` the result is ``"mod.foo.0.linear"``.
    Non-dict inputs and empty dicts are returned as ``str(module_stack)``.
    """
    if not isinstance(module_stack, dict):
        return str(module_stack)
    module_values_list = list(module_stack.values())
    if len(module_values_list) > 0:
        owning_module = module_values_list[-1][0]
        return str(owning_module)
    else:
        return str(module_stack)


def extract_results_from_loggers(
    model: GraphModule,
) -> Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]:
    """For a given model, extract the tensor stats and related information for
    each debug handle.

    Only direct children of ``model`` that are OutputLoggers with at least one
    recorded stat are considered.

    Returns:
        A dict keyed by the debug_handle id whose values are
        ``(node_name, nn_module_stack, stats)`` tuples, where ``stats`` is the
        list of Tensors recorded by the logger."""
    # Results maps debug handle to a tensor list for each model being compared.
    handles: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]] = {}
    for _name, module in model.named_children():
        if isinstance(module, OutputLogger) and len(module.stats) > 0:
            handles[module.debug_handle] = (
                module.node_name,
                module.nn_module_stack,
                module.stats,
            )

    return handles


def compare_results(
    ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
    actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
) -> Dict[int, NodeAccuracySummary]:
    """Given two dicts mapping `debug_handle_id` (int) to logger results,
    return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
    comparison information like SQNR, MSE etc.

    Handles present in ``ref_results`` but missing from ``actual_results`` are
    skipped (and logged at debug level). When the two stat lists for a handle
    have different lengths, only the common prefix is compared.

    Args:
        ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id
        actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id

    Returns:
        Dict[int, NodeAccuracySummary]
    """
    comparisons = {}
    for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items():
        if debug_handle not in actual_results:
            log.debug(
                "Cannot compare for handle %s because it wasn't found in the transformed model",
                debug_handle,
            )
            continue
        actual_name, actual_stack, actual_stats = actual_results[debug_handle]
        if len(actual_stats) != len(ref_stats):
            # zip() below truncates to the shorter list; make that visible.
            log.debug(
                "Handle %s has %d actual stats but %d reference stats; "
                "comparing only the first %d",
                debug_handle,
                len(actual_stats),
                len(ref_stats),
                min(len(actual_stats), len(ref_stats)),
            )
        comparisons[debug_handle] = NodeAccuracySummary(
            handle=debug_handle,
            actual_node_name=actual_name,
            actual_module_stack=_module_stack_to_str(actual_stack),
            ref_node_name=ref_name,
            ref_module_stack=_module_stack_to_str(ref_stack),
            results=[
                QuantizationComparisonResult(actual=a, ref=b)
                for a, b in zip(actual_stats, ref_stats)
            ],
        )

    return comparisons