xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/insert_requantize.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.qualcomm.utils.constants import (
    QCOM_QUANT_ATTRS,
    QCOM_QUANTIZED_IO,
    QCOM_REQUANTIZE,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class InsertRequantize(ExportPass):
    """
    This pass inserts a convert op for operators whose input and activation
    have different quantization specs. The convert op is a QNN-specific op
    that performs the requantization in the QNN backend.
    """

    # Ops that have multiple outputs but should run the _single_output_annotation
    # logic instead of _multi_output_annotation. An op may be added to this set
    # because we don't use its 2nd output, the 2nd output is an integer, etc.
    multi_output_op_ignore_set = {
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.topk.default,
    }

    def __init__(
        self,
        edge_program: torch.export.ExportedProgram,
    ):
        super(InsertRequantize, self).__init__()
        self.edge_program = edge_program

    # TODO: Implement this function when we have an op with
    # multiple outputs that requires quant attributes.
    def _multi_output_annotation(self) -> None:
        raise NotImplementedError("requant is not implemented for multi output yet")

    def _single_output_annotation(
        self, gm: torch.fx.GraphModule, n: torch.fx.Node
    ) -> None:
        with gm.graph.inserting_after(n):
            users = list(n.users.keys())
            # Insert a _to_copy node right after n; the QNN backend lowers it
            # to a convert (requantize) op.
            inserted_n = gm.graph.create_node(
                "call_function",
                exir_ops.edge.aten._to_copy.default,
                (n,),
            )

            # The copy inherits n's shape/dtype and takes over the requantize
            # spec as its own quant attributes.
            inserted_n.meta["val"] = n.meta["val"]
            inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE)
            if n.meta.get(QCOM_QUANTIZED_IO):
                inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO]

            # Reroute all original users of n to consume the converted value.
            for user in users:
                user.replace_input_with(n, inserted_n)

    def _insert(self, graph_module: torch.fx.GraphModule) -> None:
        for n in graph_module.graph.nodes:
            if QCOM_REQUANTIZE not in n.meta:
                continue
            # Single-tensor outputs (and the ops we deliberately treat as such)
            # go through the single-output path; multi-output requantize is not
            # supported yet.
            if (
                isinstance(n.meta["val"], torch._subclasses.fake_tensor.FakeTensor)
                or n.target in self.multi_output_op_ignore_set
            ):
                self._single_output_annotation(graph_module, n)
            else:
                self._multi_output_annotation()

    def call(self, graph_module: torch.fx.GraphModule):
        self._insert(graph_module)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
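

# Illustrative usage sketch (not part of this module; `edge_program` and the
# surrounding setup are assumptions about the caller's context). Like other
# ExportPass subclasses, an instance can be invoked on a GraphModule and
# returns a PassResult:
#
#     requantize_pass = InsertRequantize(edge_program)
#     result = requantize_pass(edge_program.graph_module)
#     graph_module = result.graph_module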