xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/recompose_pixel_unshuffle.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.exir.dialects._ops import ops as exir_ops
8from executorch.exir.pass_base import ExportPass, PassResult
9
10
class RecomposePixelUnshuffle(ExportPass):
    """
    Merge decomposed operators from mathematically equivalent implementation
    back to one super node.

    The matched pattern is the decomposition of ``pixel_unshuffle(x, r)`` for
    a 4-D input ``x`` of shape (B, C, H, W)::

        v = view(x, [B, C, H // r, r, W // r, r])
        p = permute(v, [0, 1, 3, 5, 2, 4])
        out = reshape(p, [B, C * r * r, H // r, W // r])

    Every match is replaced by a single ``pixel_unshuffle(x, r)`` node, and
    the now-unused view/permute/reshape chain is removed by dead-code
    elimination.

    Args:
        quantization_capture: When True, match the ATen ops seen during
            quantization capture (``view`` / ``permute`` / ``clone`` +
            ``_unsafe_view``). Otherwise, match the edge-dialect ``*_copy`` ops.
    """

    # Permutation that moves both block factors next to the channel dim.
    # Any other 6-D permutation produces a different channel layout, so it
    # must not be recomposed into pixel_unshuffle.
    _PERMUTE_DIMS = [0, 1, 3, 5, 2, 4]

    def __init__(self, quantization_capture=False):
        super(RecomposePixelUnshuffle, self).__init__()
        self.reshape_target = exir_ops.edge.aten.view_copy.default
        self.permute_target = exir_ops.edge.aten.permute_copy.default
        self.view_target = exir_ops.edge.aten.view_copy.default
        self.op = exir_ops.edge.aten.pixel_unshuffle.default

        self.quantization_capture = quantization_capture
        if quantization_capture:
            self.reshape_target = torch.ops.aten._unsafe_view.default
            self.permute_target = torch.ops.aten.permute.default
            self.view_target = torch.ops.aten.view.default
            self.op = torch.ops.aten.pixel_unshuffle.default
        # Op expected between permute and the final reshape during
        # quantization capture (only consulted when quantization_capture).
        self.clone_target = torch.ops.aten.clone.default

    @staticmethod
    def _is_call(node, target):
        """Return True if *node* is an fx ``call_function`` node for *target*."""
        return (
            isinstance(node, torch.fx.Node)
            and node.op == "call_function"
            and node.target == target
        )

    def _match(self, node):
        """Return ``(input_node, factor)`` if *node* ends the pattern, else None.

        Each guard runs before the next one dereferences ``args``. Nodes that
        do not belong to the pattern (placeholders, unrelated ops) are skipped
        instead of raising.
        """
        if not self._is_call(node, self.reshape_target):
            return None
        out_shape = node.args[1]
        if len(out_shape) != 4:
            return None

        permute_node = node.args[0]
        if self.quantization_capture:
            # Clone op still exists between permute and reshape_target during
            # quantization, so step over it (and only it) to reach the permute.
            if not self._is_call(permute_node, self.clone_target):
                return None
            permute_node = permute_node.args[0]
        if not self._is_call(permute_node, self.permute_target):
            return None
        dims = permute_node.args[1]
        # Normalise negative dims before comparing with the canonical order.
        if len(dims) != 6 or [d % 6 for d in dims] != self._PERMUTE_DIMS:
            return None

        view_node = permute_node.args[0]
        if not self._is_call(view_node, self.view_target):
            return None
        in_shape = view_node.args[1]
        if len(in_shape) != 6:
            return None

        # view shape:    (B, C, H // r, r, W // r, r)
        # reshape shape: (B, C * r * r, H // r, W // r)
        b_in, c_in, h_in, s_h, w_in, s_w = in_shape
        b_out, c_out, h_out, w_out = out_shape
        if (
            b_out != b_in
            or s_h != s_w  # pixel_unshuffle has a single downscale factor
            or c_out != c_in * s_h * s_w
            or h_out != h_in
            or w_out != w_in
        ):
            return None
        return view_node.args[0], s_h

    def call(self, graph_module: torch.fx.GraphModule):
        """Replace every matched decomposition with one pixel_unshuffle node.

        Returns:
            ``PassResult(graph_module, True)`` after recompiling the module.
        """
        graph = graph_module.graph
        # Snapshot the node list because new nodes are inserted while scanning.
        for node in list(graph.nodes):
            match = self._match(node)
            if match is None:
                continue
            input_node, factor = match
            with graph.inserting_after(node):
                pixel_unshuffle_node = graph.create_node(
                    "call_function", self.op, (input_node, factor)
                )
            # Rewire consumers of the final reshape to the fused node; the old
            # chain becomes unused and is dropped by eliminate_dead_code below.
            for user in node.users.copy():
                user.replace_input_with(node, pixel_unshuffle_node)
            # copy metadata (the reshape output has the same shape/dtype)
            pixel_unshuffle_node.meta = node.meta

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
90