# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


def _is_view_copy(node: torch.fx.Node) -> bool:
    """Return True if ``node`` calls the edge-dialect ``view_copy.default`` op."""
    return (
        node.op == "call_function"
        and node.target == exir_ops.edge.aten.view_copy.default
    )


def _fold_view_copy_chains(graph: torch.fx.Graph) -> bool:
    """
    Collapse chains of view_copy nodes in ``graph`` in place, then run
    dead-code elimination.

    Every view_copy whose input is another view_copy with no other users is
    re-pointed at that upstream view's own input. Nodes are visited in graph
    order, so upstream views are rewritten before their consumers and a whole
    chain ``x -> v1 -> v2 -> ... -> vN`` collapses to ``x -> vN``.

    Args:
        graph: The FX graph to rewrite in place.

    Returns:
        True if a chain was folded or dead-code elimination removed nodes.
    """
    folded = False
    for node in graph.nodes:
        if not _is_view_copy(node) or not node.args:
            continue
        src = node.args[0]
        if not isinstance(src, torch.fx.Node) or not _is_view_copy(src):
            continue
        # An upstream view that is also read by other nodes must stay.
        if not src.args or len(src.users) != 1:
            continue
        # The downstream node survives rather than the upstream one, so:
        #  * whatever metadata was recorded for it already corresponds to the
        #    final output of the chain, and
        #  * its shape argument (which may reference nodes defined after the
        #    upstream view) is never moved to an earlier position in the graph.
        node.args = (src.args[0], *node.args[1:])
        folded = True

    # Bypassed views now have no users. DCE removes them, plus any other dead
    # nodes, as this pass has always done.
    dce_changed = graph.eliminate_dead_code()
    return folded or bool(dce_changed)


def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Find chains of view_copy nodes and merge them into one view_copy node.

    Only view_copy nodes whose single user is another view_copy are merged
    away. A view_copy that is read by any other node is kept.

    Args:
        graph: The FX graph to rewrite; it is modified in place.

    Returns:
        The same ``graph`` object, for call-chaining convenience.
    """
    _fold_view_copy_chains(graph)
    return graph


class FuseViewCopyTransform(ExportPass):
    """Export pass that collapses chains of edge-dialect view_copy ops."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = _fold_view_copy_chains(graph_module.graph)
        if modified:
            # Regenerate the module's Python code to reflect the new graph.
            graph_module.recompile()
        # Report the real outcome so fixed-point pass managers can stop early.
        return PassResult(graph_module, modified)