xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/pt2e/_numeric_debugger.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import copy
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import torch
from torch.ao.ns.fx.utils import compute_sqnr
from torch.fx import GraphModule, Node
from torch.nn import functional as F


NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"
CUSTOM_KEY = "custom"

log = logging.getLogger(__name__)

def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
    """Attach a numeric_debug_handle_id to every node in the model except placeholder
    and output nodes. The graph nodes of the input model are modified in place.
    """
    unique_id = 0
    for node in graph_module.graph.nodes:
        if node.op in ["output", "placeholder"]:
            continue

        if CUSTOM_KEY not in node.meta:
            node.meta[CUSTOM_KEY] = {}

        if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]:
            node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
            unique_id += 1

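
# Illustrative sketch, not part of the original file: attaching debug handles to a
# traced toy model and inspecting them. `_DemoNet` and the input sizes are
# hypothetical, and a real PT2E flow would typically run this on a torch.export'ed
# graph rather than a symbolically traced one.
def _example_generate_numeric_debug_handle() -> None:
    class _DemoNet(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    gm = torch.fx.symbolic_trace(_DemoNet())
    generate_numeric_debug_handle(gm)
    for node in gm.graph.nodes:
        if node.op not in ("placeholder", "output"):
            # Every non-placeholder/non-output node now carries a unique handle.
            print(node.name, node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY])
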
class OutputLogger(torch.nn.Module):
    """
    Base class for capturing output values of nodes in a GraphModule. It currently
    only captures Tensor outputs, but can be extended to other output types later
    if needed.
    """

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(
        self,
        debug_handle: int,
        node_name: Optional[str] = None,
        nn_module_stack: Optional[object] = None,
    ) -> None:
        super().__init__()
        self.node_name = node_name
        self.nn_module_stack = nn_module_stack
        self.debug_handle = debug_handle
        self.stats: List[torch.Tensor] = []

    def forward(self, x: object) -> object:
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        return x

    def extra_repr(self) -> str:
        return (
            f"debug_handle={self.debug_handle}, node_name={self.node_name}, "
            f"nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)}"
        )

def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:
    """For a given node, add an OutputLogger that observes the output of that node,
    and rewrite all of the node's users to consume the OutputLogger's output instead.
    The OutputLogger carries the debug_handle, which can be used to compare
    graphs after transforms."""

    # to avoid circular dep
    from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

    # add a logger after the node
    with model.graph.inserting_after(node):
        get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger")
        logger_name = get_new_attr_name(model)
        setattr(
            model,
            logger_name,
            OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")),
        )
        logger_node = model.graph.call_module(logger_name, (node,), {})

    orig_users = list(node.users.keys())
    for user_node in orig_users:
        if user_node is logger_node:
            continue
        user_node.replace_input_with(node, logger_node)

    return logger_node

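
# Illustrative sketch, not part of the original file: manually inserting a logger
# after one node and confirming that it records that node's output when the module
# runs. `_DemoNet`, the input sizes, and the debug handle value are hypothetical.
def _example_insert_logger() -> None:
    class _DemoNet(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    gm = torch.fx.symbolic_trace(_DemoNet())
    linear_node = next(n for n in gm.graph.nodes if n.op == "call_module")
    logger_node = _insert_logger(gm, linear_node, debug_handle=0)
    gm.recompile()

    gm(torch.randn(2, 4))
    logger = getattr(gm, logger_node.target)
    # The logger captured one Tensor: the output of the linear node.
    assert len(logger.stats) == 1
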
def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
    """Add output loggers to nodes that have a numeric_debug_handle

    Args:
        model (GraphModule): original model
    Returns:
        a model with output loggers for all nodes that have a numeric_debug_handle_id
    """
    # don't change the original model
    model = copy.deepcopy(model)
    for n in model.graph.nodes:
        if (
            CUSTOM_KEY not in n.meta
            or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY]
        ):
            continue
        numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
        _insert_logger(model, n, numeric_debug_handle)

    model.recompile()
    return model

@dataclass(frozen=True)
class QuantizationComparisonResult:
    actual: torch.Tensor
    ref: torch.Tensor

    @property
    def mse_loss(self) -> torch.Tensor:
        return F.mse_loss(
            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
        )

    @property
    def sqnr(self) -> torch.Tensor:
        return compute_sqnr(
            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
        )

    def __repr__(self) -> str:
        # Don't include the tensors themselves as they are quite large to print
        # out.
        return (
            f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})"
        )

    def __post_init__(self) -> None:
        if not isinstance(self.actual, torch.Tensor):
            raise ValueError(
                f"`self.actual` value must be a Tensor, got: {self.actual}"
            )

        if not isinstance(self.ref, torch.Tensor):
            raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")

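
# Illustrative sketch, not part of the original file: comparing a crudely
# "quantized" tensor against its float reference. The tensor size and the scale
# value below are arbitrary.
def _example_quantization_comparison_result() -> None:
    ref = torch.randn(16)
    scale = 0.1
    # Simulate quantize -> dequantize by rounding onto a fixed grid.
    actual = torch.round(ref / scale) * scale
    result = QuantizationComparisonResult(actual=actual, ref=ref)
    # __repr__ prints only the summary statistics, not the tensors themselves.
    print(result)
    print(result.mse_loss)  # scalar Tensor
    print(result.sqnr)      # scalar Tensor
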
@dataclass(frozen=True)
class NodeAccuracySummary:
    handle: int
    actual_node_name: str
    actual_module_stack: str
    ref_node_name: str
    ref_module_stack: str
    results: Sequence[QuantizationComparisonResult]

def _module_stack_to_str(module_stack: object) -> str:
    """Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear")
    to "mod.foo.0.linear"
    """
    if not isinstance(module_stack, dict):
        return str(module_stack)
    module_values_list = list(module_stack.values())
    if len(module_values_list) > 0:
        owning_module = module_values_list[-1][0]
        return str(owning_module)
    else:
        return str(module_stack)

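
# Illustrative sketch, not part of the original file: nn_module_stack metadata is a
# dict whose values are (module path, module type) pairs, with the innermost owning
# module last. The keys and paths below are made up for illustration.
def _example_module_stack_to_str() -> None:
    stack = {
        "L__self__": ("mod", "MyModel"),
        "L__self___foo": ("mod.foo", "torch.nn.Sequential"),
        "L__self___foo_0_linear": ("mod.foo.0.linear", "torch.nn.Linear"),
    }
    assert _module_stack_to_str(stack) == "mod.foo.0.linear"
    # Non-dict inputs (e.g. None) are simply stringified.
    assert _module_stack_to_str(None) == "None"
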
def extract_results_from_loggers(
    model: GraphModule,
) -> Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]:
    """For a given model, extract the tensor stats and related information for each debug handle.

    Returns:
        A dict keyed by debug_handle id whose values are tuples of
        (node_name, nn_module_stack, list of Tensors recorded by the logger)"""
    # Results maps debug handle to a tensor list for each model being compared.
    handles: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]] = {}
    for _name, module in model.named_children():
        if isinstance(module, OutputLogger) and len(module.stats) > 0:
            handles[module.debug_handle] = (
                module.node_name,
                module.nn_module_stack,
                module.stats,
            )

    return handles

def compare_results(
    ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
    actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
) -> Dict[int, NodeAccuracySummary]:
    """Given two dicts mapping `debug_handle_id` (int) to lists of tensors,
    return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
    comparison information such as SQNR, MSE, etc.

    Args:
        ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id
        actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id

    Returns:
        Dict[int, NodeAccuracySummary]
    """
    comparisons = {}
    for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items():
        if debug_handle not in actual_results:
            log.debug(
                "Cannot compare for handle %s because it wasn't found in the transformed model",
                debug_handle,
            )
            continue
        actual_name, actual_stack, actual_stats = actual_results[debug_handle]
        comparisons[debug_handle] = NodeAccuracySummary(
            handle=debug_handle,
            actual_node_name=actual_name,
            actual_module_stack=_module_stack_to_str(actual_stack),
            ref_node_name=ref_name,
            ref_module_stack=_module_stack_to_str(ref_stack),
            results=[
                QuantizationComparisonResult(actual=a, ref=b)
                for a, b in zip(actual_stats, ref_stats)
            ],
        )

    return comparisons
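
# Illustrative end-to-end sketch, not part of the original file. `_DemoNet` is a
# hypothetical model, and the "transformed" model here is just a deep copy with
# perturbed weights; in a real PT2E flow the transformed model would come from a
# quantization pass that preserves the debug handles stored in node.meta.
def _example_compare_results() -> None:
    class _DemoNet(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    gm = torch.fx.symbolic_trace(_DemoNet())
    generate_numeric_debug_handle(gm)

    # "Transform" the model: copy it (node.meta, including the debug handles, is
    # preserved by deepcopy) and perturb the weights slightly.
    transformed = copy.deepcopy(gm)
    with torch.no_grad():
        transformed.linear.weight.add_(
            0.01 * torch.randn_like(transformed.linear.weight)
        )

    example_input = torch.randn(2, 4)

    ref_logged = prepare_for_propagation_comparison(gm)
    ref_logged(example_input)
    ref_results = extract_results_from_loggers(ref_logged)

    actual_logged = prepare_for_propagation_comparison(transformed)
    actual_logged(example_input)
    actual_results = extract_results_from_loggers(actual_logged)

    # One NodeAccuracySummary per debug handle that appears in both models.
    for handle, summary in compare_results(ref_results, actual_results).items():
        print(handle, summary.ref_node_name, summary.results[0])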