xref: /aosp_15_r20/external/executorch/exir/passes/normalize_view_copy_base_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import logging
10
11import torch
12
13from executorch.exir.dialects._ops import ops
14from torch.fx.passes.infra.pass_base import PassBase, PassResult
15
16
def _is_view_copy(node: torch.fx.Node) -> bool:
    """
    Tell whether ``node`` is a call to a ``view_copy.default`` op.

    Both the ATen overload (``torch.ops.aten``) and the Edge-dialect overload
    (``ops.edge.aten``) are recognized.
    """
    # Only call_function nodes have an op overload as their target; bail out
    # early so the target tuple is never built for other node kinds.
    if node.op != "call_function":
        return False
    view_copy_targets = (
        torch.ops.aten.view_copy.default,
        ops.edge.aten.view_copy.default,
    )
    return node.target in view_copy_targets
22
23
class NormalizeViewCopyBasePass(PassBase):
    """
    Point each view_copy to the first upstream non-view.

    After this pass, the base of each view_copy is not a view_copy.

    When combined with dead-code elimination, this pass removes redundant
    view_copy nodes.

    The rewrite relies on a chain view_copy(view_copy(x, s1), s2) being
    replaceable by view_copy(x, s2); only the base argument is changed, the
    target size and any other arguments are kept.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rebase every view_copy in ``graph_module`` and its GraphModule submodules.

        The graphs are modified in place.

        Args:
            graph_module: The module whose graph (and nested GraphModule
                graphs) are rewritten.

        Returns:
            A PassResult wrapping ``graph_module``; ``modified`` is True iff at
            least one view_copy had its base changed.
        """
        n_updated = 0
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            n_updated_in_module = 0
            for node in module.graph.nodes:
                if not _is_view_copy(node):
                    continue
                base = node.args[0]
                if not _is_view_copy(base):
                    continue
                # Follow the whole view_copy chain instead of assuming the base
                # was already rebased earlier in the loop: that assumption only
                # holds if the node list is in topological order.
                while _is_view_copy(base):
                    base = base.args[0]
                # Replace only the base; keep the remaining args (normally just
                # ``size``) instead of requiring exactly two positional args.
                node.args = (base, *node.args[1:])
                n_updated_in_module += 1

            # Regenerate the Python code only for graphs that actually changed.
            if n_updated_in_module > 0:
                module.recompile()
            n_updated += n_updated_in_module

        logging.debug("Updated the base on %d view_copy nodes.", n_updated)
        return PassResult(graph_module, n_updated > 0)

    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Check the postcondition: no view_copy has a view_copy as its base.

        Args:
            graph_module: The module to check, including nested GraphModules.

        Raises:
            AssertionError: If some view_copy still takes a view_copy as input.
        """
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if not _is_view_copy(node):
                    continue
                base = node.args[0]
                # Raise explicitly rather than using ``assert`` so the check is
                # not stripped under ``python -O``; the exception type is kept.
                if _is_view_copy(base):
                    raise AssertionError(
                        f"view_copy node {node.name} still has view_copy base "
                        f"{base.name}"
                    )
63