# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import enum
import operator
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import GraphModule
from torch.fx.graph import Node

from .ns_types import NSNodeTargetType, NSResultsType


toq = torch.ops.quantized


# TODO(future PR): consider deleting this enum and using the torch types
# directly.  This might be tricky because it is not a one to one mapping.
class NodeInputOrOutputType(enum.Enum):
    FP32 = enum.auto()  # torch.float
    INT8 = enum.auto()  # torch.qint8 or torch.quint8
    FP16 = enum.auto()  # torch.float16
    UNKNOWN = enum.auto()  # we cannot determine input/output dtype
    # TODO(future PR): while these functions can support multiple dtypes,
    #   for the purposes of numerical debugging we want to get the actual
    #   dtype used in the model. We will likely need some kind of dtype
    #   propagation to estimate this.
    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
    # TODO(future PRs): dynamic quant, fake quant, etc


def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
    # TODO(future PR): clean this up
    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    if node.op == "call_function":
        if node.target in FUNS_IO_TYPE_FP32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        elif node.target in FUNS_IO_TYPE_FP16:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        elif node.target in FUNS_IO_TYPE_INT8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_module":
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        is_known_fp32_or_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if (
            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
            or is_known_fp32_or_int8_input_module
        ):
            # A logger or observer's input and output type is the output
            # type of the preceding node.
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        is_known_fp32_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32  # type: ignore[arg-type]
        )
        is_known_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_input_module:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        elif is_known_int8_input_module:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_method":
        if node.target == "dequantize":
            # Dequantize is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance.
            prev_node = get_normalized_nth_input(node, gm, 0)
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, NodeInputOrOutputType.FP32)

        elif node.target == "to":
            # `to` is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance. We also look up the target
            # of `to` and return the correct output type.
            prev_node = get_normalized_nth_input(node, gm, 0)
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map
            )

            cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
            assert (
                cur_node_dtype_target is torch.float16
            ), f"{cur_node_dtype_target} handling needs to be added"

            return (prev_node_output_type, NodeInputOrOutputType.FP16)

        elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)

        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
    else:
        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
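# Example (illustrative sketch, not part of the original file): classifying the
# I/O dtypes of each node in a symbolically traced model. `OutputLogger` and
# `get_node_type_to_io_type_map` are assumed to come from torch.ao.ns.
#
#   from torch.ao.ns._numeric_suite_fx import OutputLogger
#   from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map
#
#   gm = torch.fx.symbolic_trace(model)
#   io_map = get_node_type_to_io_type_map()
#   for n in gm.graph.nodes:
#       input_type, output_type = get_node_first_input_and_output_type(
#           n, gm, OutputLogger, io_map
#       )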


def get_node_input_qparams(
    node: Node,
    gm: GraphModule,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
    """
    Returns the qparams (scale, zero_point) of the first input to `node`,
    if they can be inferred from the graph.
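
    Example (illustrative sketch)::

        # graph: ... -> torch.quantize_per_tensor(x, scale, zp, dtype) -> node
        # get_node_input_qparams(node, gm, node_type_to_io_type_map)
        # returns (scale, zp), read from the attributes referenced by the
        # quantize call.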
169    """
170    prev_node = get_normalized_nth_input(node, gm, 0)
171
172    if not isinstance(prev_node, Node):
173        return None
174
175    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
176
177    def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
178        scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
179        zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
180        assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
181        assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
182        scale_obj = getattr_from_fqn(gm, scale_node.target)
183        zp_obj = getattr_from_fqn(gm, zp_node.target)
184        return (scale_obj, zp_obj)
185
186    if prev_node.op == "call_function":
187        # quantize - read the args directly
188        if prev_node.target == torch.quantize_per_tensor:
189            return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
190        elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
191            return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)
192
193        return None
194        # TODO(future PR): handle more functionals
195        # TODO(future PR): handle functional ops which inherit qparams from input
196
    elif prev_node.op == "call_module":
        # get type of the module
        assert isinstance(prev_node.target, str)
        module_obj = getattr_from_fqn(gm, prev_node.target)
        if isinstance(
            module_obj,
            (
                nnq.Linear,
                nnq.Conv1d,
                nnq.Conv2d,
                nnq.Conv3d,
                nnq.BatchNorm2d,
                nnq.BatchNorm3d,
                nnq.ConvTranspose1d,
                nnq.ConvTranspose2d,
                nnq.ELU,
                nnq.GroupNorm,
                nnq.InstanceNorm1d,
                nnq.InstanceNorm2d,
                nnq.InstanceNorm3d,
                nnq.LayerNorm,
                nnq.Hardswish,
                nnq.LeakyReLU,
                nnq.ReLU6,
                nniq.BNReLU2d,
                nniq.BNReLU3d,
                nniq.ConvReLU1d,
                nniq.ConvReLU2d,
                nniq.ConvReLU3d,
                nniq.LinearReLU,
            ),
        ):
            return (module_obj.scale, module_obj.zero_point)  # type: ignore[return-value]

        is_known_fp32_or_int8_input_module = any(
            isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_or_int8_input_module:
            return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)

    return None


def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it.  If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer.  For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if _is_activation_post_process(node_obj):
            assert len(node.args) == 1
            assert isinstance(node.args[0], Node)
            node = node.args[0]
            # code duplication intended, not worth refactoring
            assert isinstance(node.target, str)
            node_obj = getattr_from_fqn(gm, node.target)
            if _is_activation_post_process(node_obj):
                assert len(node.args) == 1
                assert isinstance(node.args[0], Node)
                node = node.args[0]
    return node


def get_number_of_non_param_args(
    node: Node,
    gm: GraphModule,
) -> int:
    """
    Assumes that all non-param args occur first. Returns the number of
    non-param args expected for a node.  For example, for

      F.linear(x, weight, bias)

    Returns 1, because x is a non-param arg and weight and bias are params.
    For

      lstm_mod(x, hid)

    Returns 2, because both x and hid are non-param args.
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if isinstance(node_obj, nn.LSTM):
            return 2

    # default is 1
    return 1


def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
    """
    Returns the indices of args of the node which we should attach
    loggers to, if input logging is enabled.

    For example,
    * for (x + y), returns [0, 1]
    * for (1 + y), returns [1]
    * for (x + 1), returns [0]
    * for (linear(x, w, b)) returns [0]
    * by default, returns [0]
    """
    if len(node.args) == 0:
        return []
    if node.op == "call_function" and (
        # TODO(future PR): use relationship map instead of hardcoding
        node.target in (torch.add, torch.ops.quantized.add, operator.add)
        or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
    ):
        result = []
        for i in range(2):
            if isinstance(node.args[i], Node):
                result.append(i)
        return result
    return [0]


def get_target_type_str(node: Node, gm: GraphModule) -> str:
    """
    Returns a string representation of the type of the function or module
    pointed to by this node, or '' for other node types.
    """
    target_type = ""
    if node.op in ("call_function", "call_method"):
        target_type = torch.typename(node.target)
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        target_mod = getattr_from_fqn(gm, node.target)
        target_type = torch.typename(target_mod)
    return target_type


def rekey_logger_info_on_node_name_of_model(
    results: NSResultsType,
    model_name: str,
) -> NSResultsType:
    """
    Rekeys the layer name of a results dictionary to use node names
    from `model_name`.

    For example, transforms

        {'base_op_1_0': {'node_output': {'model_a':
          [{'ref_node_name': 'linear1', ...}]}}}

    into

        {'linear1': {'node_output': {'model_a':
          [{'ref_node_name': 'linear1', ...}]}}}

    Note: we cannot use these node names directly because they are not
    guaranteed to be consistent across models. This is why we extract
    the results first and rekey afterwards.
    """
    new_results = {}
    for old_layer_name, result_type_to_results in results.items():
        new_layer_name = None
        for model_name_to_results in result_type_to_results.values():
            for cur_model_name, list_of_results in model_name_to_results.items():
                if cur_model_name == model_name:
                    assert len(list_of_results)
                    new_layer_name = list_of_results[0]["ref_node_name"]
        if new_layer_name is not None:
            new_results[new_layer_name] = result_type_to_results
        else:
            new_results[old_layer_name] = result_type_to_results
    return new_results


def maybe_add_missing_fqns(results: NSResultsType) -> None:
    """
    If `fqn` entries are filled in for one of the models in `results`, copies
    them over to any models which do not have them filled out.

    A common use case benefiting from this is comparing a model prepared by
    quantization to a quantized model. In this case, the model prepared by
    quantization would have `fqn` entries, and the quantized model would not.
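
    Example (illustrative, not from the original file): if model `a` has
    `fqn='fc1'` recorded for a layer and model `b` has `fqn=None` for the
    corresponding entry, `'fc1'` is copied into model `b`'s entry in place.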
385    """
386
387    # Check in the first result to find any model with fqn entries defined.
388    model_name_with_fqns = None
389    for result_type_to_results in results.values():
390        for model_name_to_results in result_type_to_results.values():
391            for model_name, model_results in model_name_to_results.items():
392                if len(model_results) > 0:
393                    if model_results[0]["fqn"] is not None:
394                        model_name_with_fqns = model_name
395                        break
396            break
397        break
398
399    if model_name_with_fqns:
400        for result_type_to_results in results.values():
401            for model_name_to_results in result_type_to_results.values():
402                ref_model_results = model_name_to_results[model_name_with_fqns]
403                for model_name, model_results in model_name_to_results.items():
404                    if model_name == model_name_with_fqns:
405                        continue
406                    for i in range(len(model_results)):
407                        fqn = ref_model_results[i]["fqn"]
408                        model_results[i]["fqn"] = fqn
409
410
411def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
412    def inner(*args, **kwargs):
413        a0, a1, *a_other = args
414
415        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
416            isinstance(a0, list) and isinstance(a1, list)
417        ):
418            results = []
419            for el0, el1 in zip(a0, a1):
420                new_args = (el0, el1, *a_other)
421                results.append(inner(*new_args, **kwargs))
422            return results
423
424        elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
425            if a0.is_quantized:
426                a0 = a0.dequantize()
427            if a1.is_quantized:
428                a1 = a1.dequantize()
429
430        # for the purposes of this util, only handle floats
431        if a0.dtype != torch.float or a1.dtype != torch.float:
432            return None
433
434        new_args = (a0, a1, *a_other)
435        return f(*new_args, **kwargs)
436
437    return inner
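

# Example of the decorated behavior (illustrative sketch, not from the
# original file): quantized tensor args are dequantized before comparison,
# and tuple/list args are compared element-wise.
#
#   x = torch.randn(4)
#   y = x + 0.01 * torch.randn(4)
#   xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
#   compute_sqnr(xq, y)            # tensors: xq is dequantized, returns a scalar
#   compute_sqnr((x, x), (y, y))   # tuples: returns a list of scalar tensors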


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the signal-to-quantization-noise ratio (SQNR) between `x` and
    `y`, in decibels.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        scalar Tensor, or list of scalar Tensors if the inputs were tuples
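
    Example (illustrative)::

        >>> x = torch.randn(4, 4)
        >>> y = x + 1e-3 * torch.randn(4, 4)
        >>> compute_sqnr(x, y)  # scalar tensor; larger means y is closer to x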
451    """
452    Ps = torch.norm(x)
453    Pn = torch.norm(x - y)
454    return 20 * torch.log10(Ps / Pn)


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the normalized L2 error between `x` and `y`,
    i.e. ||x - y||_2 / ||x||_2.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        scalar Tensor, or list of scalar Tensors if the inputs were tuples
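
    Example (illustrative)::

        >>> x = torch.randn(4, 4)
        >>> compute_normalized_l2_error(x, x)  # tensor(0.) for identical inputs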
468    """
469    return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the cosine similarity between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        Tensor, or list of Tensors if the inputs were tuples
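
    Example (illustrative)::

        >>> x = torch.randn(8)
        >>> compute_cosine_similarity(x, 2 * x)  # tensor([1.]); scale-invariant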
483    """
484    # For convolutions, the shape of the quantized weight has one additional
485    # dimension compared to the shape of the fp32 weight. Match the shapes
486    # to enable cosine similarity comparison.
487    x = x.reshape(1, -1)
488    y = y.reshape(1, -1)
489    return torch.nn.functional.cosine_similarity(x, y)


def op_type_supports_shadowing(node: Node) -> bool:
    """
    Returns False for op types whose shadowing is not implemented (ops with
    multiple tensor inputs, such as add, mul, cat and stack), True otherwise.
    """
    if node.op == "call_function":
        if node.target in (
            torch.add,
            torch.mul,
            operator.add,
            operator.mul,
            torch.cat,
            torch.stack,
        ):
            # shadowing for ops with multiple tensor inputs is not implemented yet
            return False
    return True


def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
    """
    Given a node, gets the nth input to that node, normalizing
    args and kwargs to the best of its ability.
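
    Example (illustrative sketch): for a node representing
    ``torch.add(x, other=y)``, ``get_normalized_nth_input(node, gm, 1)``
    returns the node for ``y``, whether it was passed positionally or
    as a keyword argument.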
511    """
512    try:
513        norm_args_and_kwargs = node.normalized_arguments(
514            gm, normalize_to_only_use_kwargs=True
515        )
516        if norm_args_and_kwargs is not None:
517            norm_args, norm_kwargs = norm_args_and_kwargs
518            assert len(norm_args) + len(norm_kwargs) > idx
519            if idx < len(norm_args):
520                return norm_args[idx]
521            else:
522                # note: in Python 3.7+ dicts are ordered
523                return list(norm_kwargs.values())[idx]
524        else:
525            assert len(node.args) + len(node.kwargs) > idx
526            if idx < len(node.args):
527                return node.args[idx]  # type: ignore[return-value]
528            else:
529                kwargs_idx = idx + len(node.args)
530                return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
531    except RuntimeError:
532        # this RuntimeError happens when node argument normalization
533        # requires typehints to proceed, such as for torch.add where
534        # either the first, second or both arguments could be tensors
535        assert len(node.args) + len(node.kwargs) > idx
536        if idx < len(node.args):
537            return node.args[idx]  # type: ignore[return-value]
538        else:
539            kwargs_idx = idx + len(node.args)
540            return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
541