# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging

import torch

from executorch.exir.dialects._ops import ops
from torch.fx.passes.infra.pass_base import PassBase, PassResult


def _is_view_copy(node: torch.fx.Node) -> bool:
    """Return True if ``node`` calls the aten or edge-dialect ``view_copy`` op."""
    return node.op == "call_function" and node.target in (
        torch.ops.aten.view_copy.default,
        ops.edge.aten.view_copy.default,
    )


class NormalizeViewCopyBasePass(PassBase):
    """
    Point each view_copy to the first upstream non-view.

    After this pass, the base of each view_copy is not a view_copy.

    When combined with dead-code elimination, this pass removes redundant
    view_copy nodes.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite view_copy chains so each view_copy reads from a non-view_copy.

        Every GraphModule reachable through ``graph_module.modules()`` is
        processed. Only the first (base) argument of each view_copy is
        rewritten; any remaining positional args and kwargs are preserved.

        Args:
            graph_module: The graph module to transform in place.

        Returns:
            A ``PassResult`` wrapping ``graph_module``. ``modified`` is True
            iff at least one view_copy had its base rewritten.
        """
        n_updated = 0
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            n_updated_in_module = 0
            for node in module.graph.nodes:
                if not _is_view_copy(node):
                    continue
                base = node.args[0]
                if isinstance(base, torch.fx.Node) and _is_view_copy(base):
                    # Skip over the intermediate view_copy by pointing this node
                    # at the base's base. The base's base is not a view_copy:
                    # nodes are visited in topological order, so `base` was
                    # already normalized earlier in this loop.
                    node.args = (base.args[0], *node.args[1:])
                    n_updated_in_module += 1

            # Regenerate Python code only for graphs that actually changed.
            if n_updated_in_module > 0:
                module.recompile()
            n_updated += n_updated_in_module

        logging.debug("Updated the base on %d view_copy nodes.", n_updated)
        return PassResult(graph_module, n_updated > 0)

    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        """Check the postcondition: no view_copy has a view_copy as its base.

        Raises:
            AssertionError: If some view_copy node still reads from another
                view_copy node. Raised explicitly, rather than via ``assert``,
                so the check is not stripped under ``python -O``.
        """
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if not _is_view_copy(node):
                    continue
                base = node.args[0]
                if isinstance(base, torch.fx.Node) and _is_view_copy(base):
                    raise AssertionError(
                        f"view_copy node {node.name} still has view_copy "
                        f"base {base.name}"
                    )