xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/convert_interpolate_with_upsample2d.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.exir.dialects._ops import ops as exir_ops
8from executorch.exir.pass_base import ExportPass, PassResult
9from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
10
11
class ConvertInterpolateWithUpsample2D(ExportPass):
    """
    Merge the decomposed operators of ``torch.nn.functional.interpolate`` back
    into a single upsample super node.

    Each source partition that originated from ``interpolate`` is replaced by
    one ``upsample_nearest2d`` node (when the partition's output node targets
    ``aten.index.Tensor``) or one ``upsample_bilinear2d`` node (otherwise).
    The requested output size is the last two dimensions of the shape stored
    in the partition output's ``meta["val"]``.

    TODO: Currently we only map to the upsample2d version; the capability
    should be extended by reverse engineering the decomposition process.
    """

    def __init__(self):
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule):
        """Replace every decomposed interpolate partition with an upsample2d node.

        Args:
            graph_module: module whose graph is rewritten in place.

        Returns:
            ``PassResult`` wrapping ``graph_module``; always reported as modified.
        """
        graph = graph_module.graph
        partitions = get_source_partitions(graph, [torch.nn.functional.interpolate])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # NOTE(review): assumes the first input node is the tensor being
                # resized and the first output node is the partition's final
                # result — TODO confirm for all interpolate decompositions.
                input_node = src_partition.input_nodes[0]
                output_node = src_partition.output_nodes[0]
                # Only the spatial (last two) dims are passed as output_size;
                # the optional scale arguments are omitted.
                output_size = list(output_node.meta["val"].shape)[-2:]

                # TODO: robust way to get the configuration parameters and operator
                # please check torch/_decomp/decomposition.py for details
                if output_node.target.__name__ == "aten.index.Tensor":
                    # nearest_2d
                    # args: input, output_size, scales_h, scales_w
                    upsample_op = exir_ops.edge.aten.upsample_nearest2d.default
                    args = (input_node, output_size)
                else:
                    # bilinear_2d
                    # args: input, output_size, align_corners, scales_h, scales_w
                    upsample_op = exir_ops.edge.aten.upsample_bilinear2d.default
                    args = (input_node, output_size, False)

                with graph.inserting_after(input_node):
                    upsample2d_node = graph.create_node(
                        "call_function", upsample_op, args
                    )
                # Redirect every consumer of the decomposed result to the new node.
                output_node.replace_all_uses_with(upsample2d_node)
                # Copy (rather than alias) the metadata so the new node does not
                # share a mutable dict with the node that is about to be removed.
                upsample2d_node.meta = output_node.meta.copy()

        # The decomposed nodes are now unused and get dropped here.
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
57