xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/build_quant_io.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
8
9from executorch.exir.pass_base import ExportPass, PassResult
10from executorch.exir.tensor import TensorSpec
11
12
class BuildQuantIo(ExportPass):
    """
    Assign the quantized dtype to the spec of the call_delegate node so that
    the lowering process sees the correct I/O dtypes.
    """

    def __init__(self):
        super().__init__()

    def _make_spec(self, x):
        """
        Build one spec entry for a single value.

        Returns a TensorSpec for tensors, the value itself for Python
        int/bool/float scalars, and None for anything else.
        """
        if isinstance(x, torch.Tensor):
            return TensorSpec.from_tensor(x)
        if isinstance(x, (int, bool, float)):
            return x
        return None

    def _build(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Rewrite node metadata in place so the delegate's spec is quantized.

        Every node carrying QCOM_QUANTIZED_IO in its meta has meta["val"]
        cast to that dtype. The values of all call_function nodes whose name
        contains "getitem" are then collected, in graph order, into the
        delegate node's meta["spec"] tuple.

        Raises:
            AssertionError: if the graph does not contain exactly one
                call_function node named "executorch_call_delegate".
        """
        # Forcibly update the delegate node's meta['spec'] to get the correct
        # output tensor size at runtime.
        call_delegate = [
            node
            for node in graph_module.graph.nodes
            if node.op == "call_function" and node.name == "executorch_call_delegate"
        ]
        assert len(call_delegate) == 1, (
            "expected exactly one executorch_call_delegate node, "
            f"found {len(call_delegate)}"
        )

        spec = []
        for n in graph_module.graph.nodes:
            if QCOM_QUANTIZED_IO in n.meta:
                n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO])
            # NOTE(review): this matches by node name; presumably these are
            # the getitem nodes reading the delegate's outputs - confirm.
            if n.op == "call_function" and "getitem" in n.name:
                # meta["val"] was already cast above when the node is tagged,
                # so it can be used directly without a second conversion.
                spec.append(self._make_spec(n.meta["val"]))

        call_delegate[0].meta["spec"] = tuple(spec)

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Run the pass: update the metadata, then clean up and recompile.

        Returns a PassResult wrapping the same graph module, always flagged
        as modified.
        """
        self._build(graph_module)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
55