xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/annotate_decomposed.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
8from executorch.exir.pass_base import ExportPass, PassResult
9from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
10
11from .utils import dq_ops, get_quant_attrs, q_ops
12
13
class AnnotateDecomposed(ExportPass):
    """
    Add "quant_attrs" to graph nodes' meta from the QDQ information
    generated after quantization process.

    Two kinds of source partitions are handled: those reported for
    ``torch.unbind`` (attributes taken from the quantize node feeding the
    partition's dequantize input) and those reported for ``torch.stack``
    (attributes taken from the quantize node consuming the partition's
    output). Partitions that do not match the expected pattern are left
    untouched.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        """
        Args:
            edge_program: Exported program that is forwarded to
                ``get_quant_attrs`` when the quantization attributes of a
                quantize node are looked up.
        """
        super().__init__()
        self.edge_program = edge_program

    @staticmethod
    def _tag_partition(src_partition, quant_attrs) -> None:
        """Store a copy of ``quant_attrs`` in the meta of every partition node."""
        for n in src_partition.nodes:
            # Each node receives its own copy so that the stored attributes
            # are not one shared object across the partition.
            n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

    def _annotate_unbind(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Annotate the nodes of every unbind source partition whose first input
        node is a dequantize op.

        Args:
            graph_module: Graph module whose graph is searched for partitions.
        """
        partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # A partition without external inputs has nothing to inspect;
                # skip it instead of failing on the index lookup below.
                if not src_partition.input_nodes:
                    continue
                dq_node = src_partition.input_nodes[0]
                if dq_node.target not in dq_ops:
                    continue
                # The first argument of the dequantize node is used as the
                # node from which the quantization attributes are read.
                q_node = dq_node.args[0]
                quant_attrs = get_quant_attrs(self.edge_program, q_node)
                self._tag_partition(src_partition, quant_attrs)

    def _annotate_stack(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Annotate the nodes of every stack source partition whose first output
        node is first consumed by a quantize op.

        Args:
            graph_module: Graph module whose graph is searched for partitions.
        """
        partitions = get_source_partitions(graph_module.graph, [torch.stack])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # Without an output node there is no consumer to read the
                # quantization attributes from.
                if not src_partition.output_nodes:
                    continue
                output = src_partition.output_nodes[0]
                # Only the first user is considered; ``None`` marks an output
                # that has no users at all.
                first_user = next(iter(output.users), None)
                if first_user is None or first_user.target not in q_ops:
                    continue
                quant_attrs = get_quant_attrs(self.edge_program, first_user)
                self._tag_partition(src_partition, quant_attrs)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run both annotation steps on ``graph_module`` and recompile it.

        Args:
            graph_module: Graph module to annotate in place.

        Returns:
            A ``PassResult`` wrapping the same graph module, with the
            modified flag always set to ``True``.
        """
        self._annotate_unbind(graph_module)
        self._annotate_stack(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)
51