xref: /aosp_15_r20/external/executorch/backends/transforms/duplicate_dynamic_quant_chain.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import logging
8import operator
9
10import torch
11
12from torch.ao.quantization.pt2e.utils import (
13    _filter_sym_size_users,
14    _is_valid_annotation,
15)
16
17from torch.fx.node import map_arg
18from torch.fx.passes.infra.pass_base import PassBase, PassResult
19
20
# Module-level logger named after this module; its level is set to WARNING
# right below.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

# Names exported by ``from <this module> import *``: only the pass class.
__all__ = ["DuplicateDynamicQuantChainPass"]
25
# Quantize op overloads that the functions in this file accept as the "q"
# node of a q -> dq pair (compared against ``node.target``).
_QUANTIZE_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]

# Dequantize op overloads that the functions in this file accept as the "dq"
# node of a q -> dq pair (compared against ``node.target``).
_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
37
38
def _replace_input_node_with_new_node(node, input_node, new_node):
    """Rewire ``node`` so that it consumes ``new_node`` instead of ``input_node``.

    ``node.args`` and ``node.kwargs`` are both rebuilt with ``map_arg``:
    every argument equal to ``input_node`` is swapped for ``new_node`` and
    everything else is passed through unchanged. The rebuilt values are then
    assigned back onto ``node``. Nothing is returned.
    """

    def _swap(candidate: torch.fx.Node) -> torch.fx.Node:
        return new_node if candidate == input_node else candidate

    # Both replacements are computed before either attribute is written.
    node.args, node.kwargs = (
        map_arg(node.args, _swap),
        map_arg(node.kwargs, _swap),
    )
50
51
def _replicate_chose_qparam_nodes_for_q_dq(
    gm: torch.fx.GraphModule, chose_qparams_node, get_item_node_1, get_item_node_2
):
    """Give each q -> dq pair fed by ``get_item_node_1`` its own qparams nodes.

    For every user of ``get_item_node_1`` whose target is in
    ``_QUANTIZE_OPS``, the quantize node and its single dequantize user are
    rewired to fresh copies of ``chose_qparams_node``, ``get_item_node_1``
    and ``get_item_node_2``. Afterwards dead code elimination is run on
    ``gm.graph`` and ``gm`` is recompiled.

    Args:
        gm: Graph module whose graph is edited in place.
        chose_qparams_node: Node whose target must be
            ``quantized_decomposed.choose_qparams.tensor``.
        get_item_node_1: ``operator.getitem`` node; its users are scanned
            for quantize nodes.
        get_item_node_2: The second ``operator.getitem`` node to replicate.

    Raises:
        RuntimeError: If the three nodes do not have the expected targets,
            if a quantize user does not have exactly one user, or if that
            user's target is not in ``_DEQUANTIZE_OPS``.
    """
    targets_match = (
        chose_qparams_node.target
        == torch.ops.quantized_decomposed.choose_qparams.tensor
        and get_item_node_1.target == operator.getitem
        and get_item_node_2.target == operator.getitem
    )
    if not targets_match:
        raise RuntimeError(
            f"Expecting choose_qparams.tensor and getitem nodes but got {chose_qparams_node}, {get_item_node_1}, {get_item_node_2}"
        )

    # Collect the (quantize, dequantize) pairs from a snapshot of the users,
    # because the rewiring below changes ``get_item_node_1.users``.
    q_dq_pair = []
    for user in tuple(get_item_node_1.users):
        if user.target not in _QUANTIZE_OPS:
            continue
        if len(user.users) != 1:
            raise RuntimeError(f"Node {user} has more than one user")
        (dq_node,) = user.users
        if dq_node.target not in _DEQUANTIZE_OPS:
            raise RuntimeError(
                f"Node {user}'s use must be a dequantize op but got {dq_node}:{dq_node.target}"
            )
        q_dq_pair.append((user, dq_node))

    graph = gm.graph
    for quant, dequant in q_dq_pair:
        with graph.inserting_after(get_item_node_1):
            # NOTE(review): the copy order (getitem 1, getitem 2, then
            # choose_qparams) looks deliberate -- presumably each copy is
            # placed directly after ``get_item_node_1``, which would leave
            # the choose_qparams copy ahead of the getitem copies that read
            # it. Keep this order unless that is confirmed otherwise.
            getitem_1_copy = graph.node_copy(get_item_node_1)
            getitem_2_copy = graph.node_copy(get_item_node_2)
            qparams_copy = graph.node_copy(chose_qparams_node)

            # Point the getitem copies at the choose_qparams copy.
            for getitem_copy in (getitem_1_copy, getitem_2_copy):
                _replace_input_node_with_new_node(
                    getitem_copy, chose_qparams_node, qparams_copy
                )

            # Point this q/dq pair at its private getitem copies.
            for shared, private in (
                (get_item_node_1, getitem_1_copy),
                (get_item_node_2, getitem_2_copy),
            ):
                for consumer in (quant, dequant):
                    _replace_input_node_with_new_node(consumer, shared, private)

    graph.eliminate_dead_code()
    gm.recompile()
107
108
def _replicate_node_for_each_user(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """Hand every current user of ``node`` its own copy of ``node``.

    For each user, a copy of ``node`` is created inside an
    ``inserting_after(node)`` context and the user is rewired from ``node``
    to that copy. Dead code elimination is then run on ``gm.graph`` and
    ``gm`` is recompiled.
    """
    graph = gm.graph
    # Iterate over a snapshot: rewiring a user changes ``node.users``.
    for consumer in tuple(node.users):
        with graph.inserting_after(node):
            duplicate = graph.node_copy(node)
            _replace_input_node_with_new_node(consumer, node, duplicate)

    graph.eliminate_dead_code()
    gm.recompile()
118
119
120def _maybe_duplicate_dynamic_quantize_chain(
121    gm: torch.fx.GraphModule,
122    chose_qparams_node,
123    get_item_node_1,
124    get_item_node_2,
125    q_node,
126    dq_node: torch.fx.Node,
127):
128    num_dq_users = len(dq_node.users)
129    dq_node_users = list(dq_node.users.copy())
130    for user in dq_node_users:
131        annotation = user.meta.get("quantization_annotation", None)
132        if not _is_valid_annotation(annotation):
133            return
134        with gm.graph.inserting_after(dq_node):
135            new_node = gm.graph.node_copy(dq_node)
136            _replace_input_node_with_new_node(user, dq_node, new_node)
137
138    gm.graph.eliminate_dead_code()
139    gm.recompile()
140    if len(q_node.users) != num_dq_users:
141        raise RuntimeError(
142            f"Expected {num_dq_users} users of {q_node}, but got {len(q_node.users)}"
143        )
144    _replicate_node_for_each_user(gm, q_node)
145
146    # *2 because scale/zp are used both with q and dq nodes
147    if len(get_item_node_1.users) != num_dq_users * 2:
148        raise RuntimeError(
149            f"Expected {num_dq_users} users of {get_item_node_1}, but got {len(get_item_node_1.users)}"
150        )
151    if len(get_item_node_2.users) != num_dq_users * 2:
152        raise RuntimeError(
153            f"Expected {num_dq_users} users of {get_item_node_2}, but got {len(get_item_node_2.users)}"
154        )
155    _replicate_chose_qparam_nodes_for_q_dq(
156        gm, chose_qparams_node, get_item_node_1, get_item_node_2
157    )
158
159
class DuplicateDynamicQuantChainPass(PassBase):
    """Graph pass that un-shares dynamic-quantization chains.

    It looks for dequantize nodes with more than one user (as reported by
    ``_filter_sym_size_users``) that sit at the end of a
    ``choose_qparams.tensor -> getitem -> quantize -> dequantize`` chain and
    passes each such chain to ``_maybe_duplicate_dynamic_quantize_chain``.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run the pass on ``graph_module`` in place.

        Returns:
            ``PassResult(graph_module, True)``; the modified flag is always
            ``True``, whether or not any chain was duplicated.
        """
        for node in graph_module.graph.nodes:
            # Only dequantize calls are candidates.
            if node.op != "call_function" or node.target not in _DEQUANTIZE_OPS:
                continue
            # A dequantize node with at most one user has nothing to un-share.
            if len(_filter_sym_size_users(node)) <= 1:
                continue

            # Expected pattern: choose_qparam - getitem - q - dq
            q_node = node.args[0]
            if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS:
                continue

            getitem_1_node = q_node.args[1]
            getitem_2_node = q_node.args[2]
            getitem_1_matches = (
                isinstance(getitem_1_node, torch.fx.node.Node)
                and getitem_1_node.op == "call_function"
                and getitem_1_node.target == operator.getitem
            )
            if not getitem_1_matches:
                continue

            choose_qparam_node = getitem_1_node.args[0]
            choose_qparam_matches = (
                isinstance(choose_qparam_node, torch.fx.node.Node)
                and choose_qparam_node.op == "call_function"
                and choose_qparam_node.target
                == torch.ops.quantized_decomposed.choose_qparams.tensor
            )
            if not choose_qparam_matches:
                continue

            _maybe_duplicate_dynamic_quantize_chain(
                graph_module,
                choose_qparam_node,
                getitem_1_node,
                getitem_2_node,
                q_node,
                node,
            )

        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
197