xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/fold_qdq.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.exir.dialects._ops import ops as exir_ops
8from executorch.exir.pass_base import ExportPass, PassResult
9from executorch.exir.passes import dead_code_elimination_pass
10
11from .utils import dq_ops, q_ops
12
13
class FoldQDQ(ExportPass):
    """
    Erase QDQ pattern.

    Every dequantize node (target in ``dq_ops``) and every quantize node
    (target in ``q_ops``) is removed from the graph. Each removed node's users
    are rewired to consume the node's first argument directly. Quantization
    parameters that a quantize node receives as graph nodes are erased once
    nothing else uses them. This includes the input of a ``_to_copy``
    parameter node. Dead code elimination runs afterwards.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _bypass(node: torch.fx.Node) -> None:
        """Redirect every user of ``node`` to read ``node.args[0]`` instead."""
        replacement = node.args[0]
        # Snapshot the users: replace_input_with mutates node.users.
        for user in list(node.users):
            user.replace_input_with(node, replacement)

    def _fold(self, graph_module: torch.fx.GraphModule) -> None:
        """Remove q/dq nodes from ``graph_module`` in place."""
        graph = graph_module.graph

        # remove dq
        for node in graph.nodes:
            if node.target not in dq_ops:
                continue
            self._bypass(node)
            graph.erase_node(node)

        # remove q
        for node in graph.nodes:
            if node.target not in q_ops:
                continue
            source = node.args[0]

            # TODO: remove this hack as source_fn_stack is internal implementation detail of torch.export.
            # To make constant value/tensor be tagged as delegatable during partition
            if source.op == "get_attr":
                # A q node without users has nothing to copy the tag from;
                # indexing users()[0] here would raise IndexError.
                first_user = next(iter(node.users), None)
                if first_user is not None:
                    source.meta["source_fn_stack"] = first_user.meta.get(
                        "source_fn_stack"
                    )

            # Collect the q node plus any quantization parameters that are
            # graph nodes. Parameters come after the node that consumes them,
            # so each node loses its last user before its own erase attempt.
            to_be_removed = [node]
            for arg in node.args[1:]:
                if not isinstance(arg, torch.fx.node.Node):
                    continue
                to_be_removed.append(arg)
                # could be a commonly shared attribute between q & dq
                if arg.target == exir_ops.edge.aten._to_copy.default and isinstance(
                    arg.args[0], torch.fx.node.Node
                ):
                    to_be_removed.append(arg.args[0])

            # connect source node to quant users and remove quant node
            self._bypass(node)
            # dict.fromkeys drops duplicates while keeping the order above.
            for victim in dict.fromkeys(to_be_removed):
                # A parameter still referenced elsewhere (e.g. by another q
                # node) must survive; erase_node would raise on it. Unused
                # leftovers are cleaned up by dead code elimination in call().
                if victim.users:
                    continue
                graph.erase_node(victim)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Fold all q/dq nodes, recompile, and drop dead code."""
        self._fold(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)
64