xref: /aosp_15_r20/external/executorch/backends/transforms/fuse_view_copy.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import torch
10from executorch.exir.dialects._ops import ops as exir_ops
11from executorch.exir.pass_base import ExportPass, PassResult
12
13
def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Find chains of view_copy nodes and merge them into one view_copy node.

    A chain is a run of edge-dialect ``aten.view_copy.default`` nodes in which
    every node except the last has exactly one user, and that user is the next
    view_copy in the run. Only the final shape of such a chain matters, so the
    first node is rewritten to produce the last node's shape directly and the
    rest of the chain is left for dead-code elimination.

    Args:
        graph: The graph to rewrite. It is modified in place.

    Returns:
        The same ``graph`` object, after merging and dead-code elimination.
    """
    view_op = exir_ops.edge.aten.view_copy.default

    def is_view_copy(candidate: torch.fx.Node) -> bool:
        return candidate.op == "call_function" and candidate.target == view_op

    # Nodes are neither inserted nor erased inside the loop below, so these
    # positions remain valid for the whole traversal.
    position = {n: index for index, n in enumerate(graph.nodes)}

    for node in graph.nodes:
        if not is_view_copy(node):
            continue

        # Walk forward while the current node feeds exactly one view_copy.
        end_node = node
        while len(end_node.users) == 1:
            (only_user,) = end_node.users
            if not is_view_copy(only_user):
                break
            end_node = only_user
        if end_node is node:
            continue

        # The shape argument may contain nodes (e.g. symbolic sizes). Moving
        # it onto the head of the chain is only valid if each of those nodes
        # is already defined before the head; otherwise leave the chain alone
        # instead of creating a use-before-definition.
        final_shape = end_node.args[1]
        if isinstance(final_shape, (list, tuple)) and any(
            isinstance(dim, torch.fx.Node) and position[dim] >= position[node]
            for dim in final_shape
        ):
            continue

        # We can swap the first node's shape arg with the last node's shape arg.
        node.args = (node.args[0], final_shape)
        # The head now produces what the tail used to produce, so carry over
        # the tail's recorded output value (when present) to keep the shape
        # metadata consistent with the new shape argument.
        if "val" in end_node.meta:
            node.meta["val"] = end_node.meta["val"]
        end_node.replace_all_uses_with(node)

    # The bypassed view_copy nodes have no remaining consumers at this point.
    graph.eliminate_dead_code()
    return graph
41
42
class FuseViewCopyTransform(ExportPass):
    """Graph pass that collapses chains of view_copy nodes into single nodes."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run ``merge_view_copy_chains`` on ``graph_module``'s graph.

        Args:
            graph_module: The module whose graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping ``graph_module``; its modified flag is
            True only if the rewrite changed the number of nodes in the graph.
        """
        nodes_before = len(graph_module.graph.nodes)
        graph_module.graph = merge_view_copy_chains(graph_module.graph)
        # Each merged chain leaves at least one view_copy without users, which
        # dead-code elimination removes, so an unchanged node count means the
        # graph was not touched.
        nodes_after = len(graph_module.graph.nodes)
        return PassResult(graph_module, nodes_after != nodes_before)
47