xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/pt2e/duplicate_dq_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import logging
3import operator
4
5import torch
6from torch.ao.quantization.pt2e.utils import (
7    _filter_sym_size_users,
8    _is_valid_annotation,
9)
10from torch.fx.node import map_arg
11from torch.fx.passes.infra.pass_base import PassBase, PassResult
12
13
# Module-level logger named after this module; its level is pinned to WARNING
# right here at import time. Nothing in the visible code logs through it.
14logger = logging.getLogger(__name__)
15logger.setLevel(logging.WARNING)
16
# Public surface of the module: only the pass class is exported.
17__all__ = ["DuplicateDQPass"]
18
# Quantize overloads from the ``quantized_decomposed`` op namespace.
# DuplicateDQPass.call checks the node feeding a dequantize against this list
# when it looks for the dynamic-quantization chain.
19_QUANTIZE_OPS = [
20    torch.ops.quantized_decomposed.quantize_per_tensor.default,
21    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
22    torch.ops.quantized_decomposed.quantize_per_channel.default,
23]
24
# Dequantize overloads from the ``quantized_decomposed`` op namespace.
# DuplicateDQPass.call only considers ``call_function`` nodes whose target is
# in this list.
25_DEQUANTIZE_OPS = [
26    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
27    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
28    torch.ops.quantized_decomposed.dequantize_per_channel.default,
29]
30
31
32def _maybe_duplicate_dq(
33    gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
34):
35    annotation = user.meta.get("quantization_annotation", None)
36    if not _is_valid_annotation(annotation):
37        return
38    with gm.graph.inserting_after(dq_node):
39        new_node = gm.graph.node_copy(dq_node)
40
41        def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
42            if n == dq_node:
43                return new_node
44            else:
45                return n
46
47        new_args = map_arg(user.args, maybe_replace_node)
48        new_kwargs = map_arg(user.kwargs, maybe_replace_node)
49        user.args = new_args  # type: ignore[assignment]
50        user.kwargs = new_kwargs  # type: ignore[assignment]
51
52
53class DuplicateDQPass(PassBase):
54    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
55        for node in graph_module.graph.nodes:
56            if node.op == "call_function" and node.target in _DEQUANTIZE_OPS:
57                dq_users = _filter_sym_size_users(node)
58                if len(dq_users) <= 1:
59                    continue
60                # Do not duplicate dq for dynamic quantization
61                # Pattern: choose_qparam - getitem - q - dq
62                q_node = node.args[0]
63                if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS:
64                    getitem_node = q_node.args[1]
65                    if (
66                        isinstance(getitem_node, torch.fx.node.Node)
67                        and getitem_node.op == "call_function"
68                        and getitem_node.target == operator.getitem
69                    ):
70                        choose_qparam_node = getitem_node.args[0]
71                        if (
72                            isinstance(choose_qparam_node, torch.fx.node.Node)
73                            and choose_qparam_node.op == "call_function"
74                            and choose_qparam_node.target
75                            == torch.ops.quantized_decomposed.choose_qparams.tensor
76                        ):
77                            continue
78                for user in dq_users:
79                    _maybe_duplicate_dq(graph_module, node, user)
80        graph_module.graph.eliminate_dead_code()
81        graph_module.recompile()
82        return PassResult(graph_module, True)
83