xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/reduce_dynamic_range.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import itertools
7import operator
8
9import torch
10from executorch.exir.pass_base import ExportPass, PassResult
11from executorch.exir.passes import dead_code_elimination_pass
12from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
13
14
class ReduceDynamicRange(ExportPass):
    """
    Replace ``torch.finfo(torch.float32).min`` scalar operands of binary ops.

    Due to limitation in Qnn, we need to change torch.finfo(torch.float32).min
    to the smallest representable value in quantization. This pass looks up
    the source partitions of add/sub/mul/div ops and, for every partition with
    exactly one input node, rewrites each positional scalar argument of the
    partition's first node that equals ``torch.finfo(torch.float32).min`` to
    ``replacement_value`` (``-255.0`` by default).
    """

    # Sources handed to ``get_source_partitions``: Python operators, torch
    # functions and their string names. Presumably all three forms are listed
    # so partitions are found however the op was recorded in node metadata.
    binary_op_sources = [
        operator.add,
        operator.sub,
        operator.mul,
        operator.truediv,
        torch.add,
        torch.sub,
        torch.mul,
        torch.div,
        "add",
        "sub",
        "mul",
        "truediv",
    ]

    # The constant Qnn cannot handle. Hoisted so it is computed once instead
    # of on every argument comparison.
    _FP32_MIN = torch.finfo(torch.float32).min

    def __init__(self, replacement_value: float = -255.0):
        """
        Args:
            replacement_value: value substituted for every scalar operand equal
                to ``torch.finfo(torch.float32).min``. Defaults to ``-255.0``.
        """
        super().__init__()
        self.replacement_value = replacement_value

    @classmethod
    def _is_fp32_min(cls, arg) -> bool:
        """Return True if ``arg`` is a Python scalar equal to float32 min."""
        # Only plain scalars can hold the constant. Comparing other argument
        # kinds with ``==`` is pointless, and for multi-element tensors the
        # truth value of the comparison is ambiguous and raises.
        return isinstance(arg, (int, float)) and arg == cls._FP32_MIN

    def _traverse_binary_node(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Rewrite float32-min scalar operands of single-input binary partitions.

        Partitions whose ``input_nodes`` has exactly one entry are the ones
        where the second operand is not a graph value (e.g. a constant); only
        those are inspected. The node's ``args`` are updated in place, and
        only when at least one operand was actually replaced.
        """
        src_partitions = get_source_partitions(
            graph_module.graph, self.binary_op_sources
        )
        for src_partition in itertools.chain.from_iterable(src_partitions.values()):
            if len(src_partition.input_nodes) != 1:
                continue
            # NOTE(review): assumes the partition's first node is the binary
            # op itself -- confirm for ops that decompose into several nodes.
            binary_node = src_partition.nodes[0]
            # (input node 0, constant value)
            args_list = list(binary_node.args)
            changed = False
            for i, arg in enumerate(args_list):
                # Due to limitation in Qnn, we need to change
                # torch.finfo(torch.float32).min to the smallest representable
                # value in quantization.
                if self._is_fp32_min(arg):
                    args_list[i] = self.replacement_value
                    changed = True
            if changed:
                # Reassigning ``args`` rewires node users, so skip it when
                # nothing was replaced.
                binary_node.args = tuple(args_list)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Apply the rewrite, regenerate the module code, and drop dead nodes.

        Returns:
            ``PassResult`` wrapping the (mutated) ``graph_module``; the
            modified flag is always ``True``.
        """
        self._traverse_binary_node(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)
61