xref: /aosp_15_r20/external/executorch/backends/arm/_passes/tag_io_quant_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import torch
10from executorch.exir.dialects._ops import ops as exir_ops
11from executorch.exir.pass_base import ExportPass, PassResult
12
13
class TagIOQuantPass(ExportPass):
    """
    Pass run before partitioning that tags the Q/DQ nodes at the graph's
    boundaries so that they are not greedily partitioned for the device.

    Every quantize node that consumes a placeholder (graph input) and every
    dequantize node that feeds the graph output gets
    ``meta["arm_override_partition"] = False``. Float conversion has to
    happen outside a TOSA base inference profile, so these boundary
    conversions are meant to stay off the delegate.
    """

    def is_quant_node(self, node: torch.fx.node.Node) -> bool:
        """Return True if ``node`` targets one of the edge-dialect quantize ops."""
        return node.target in {
            exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        }

    def is_dequant_node(self, node: torch.fx.node.Node) -> bool:
        """Return True if ``node`` targets one of the edge-dialect dequantize ops."""
        return node.target in {
            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Tag the boundary Q/DQ nodes of ``graph_module`` in place.

        Args:
            graph_module: Graph to annotate. Only node ``meta`` entries are
                updated; no nodes are added, removed or rewired.

        Returns:
            A PassResult wrapping the same ``graph_module`` with
            ``modified=True``.
        """
        for node in graph_module.graph.nodes:
            if node.op == "placeholder":
                # Tag quantize ops applied directly to a graph input.
                for user in node.users:
                    if self.is_quant_node(user):
                        user.meta["arm_override_partition"] = False
            elif node.op == "output":
                # Tag dequantize ops that produce a graph output.
                # all_input_nodes flattens the (possibly nested) output
                # structure and yields only Node entries. Iterating it,
                # rather than node.args[0], means three cases no longer
                # raise or get missed: a single non-tuple output, non-Node
                # entries such as None, and nodes inside nested containers.
                for producer in node.all_input_nodes:
                    if self.is_dequant_node(producer):
                        producer.meta["arm_override_partition"] = False

        graph_module.recompile()
        return PassResult(graph_module, True)
52