# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


def remove_clone_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Remove clone op nodes and replace uses with parent node.

    Every ``call_function`` node whose target is the edge-dialect
    ``aten.clone.default`` op has all of its uses redirected to its first
    positional argument. Dead code elimination is then run on the graph.

    Args:
        graph: The FX graph to rewrite. It is mutated in place.

    Returns:
        The same ``graph`` object that was passed in.
    """
    clone_op = exir_ops.edge.aten.clone.default
    for node in graph.nodes:
        if node.op != "call_function" or node.target != clone_op:
            continue
        # Rewiring uses creates no new nodes, so no insertion point is needed.
        node.replace_all_uses_with(node.args[0])

    # Runs over the whole graph, so any node that is dead at this point is
    # removed along with the orphaned clones.
    graph.eliminate_dead_code()
    return graph


class RemoveCloneOpsTransform(ExportPass):
    """Export pass that applies ``remove_clone_ops`` to a graph module."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Strip clone ops from ``graph_module`` and return the pass result.

        Args:
            graph_module: The module whose graph is rewritten.

        Returns:
            A ``PassResult`` wrapping ``graph_module``.
        """
        graph_module.graph = remove_clone_ops(graph_module.graph)
        # NOTE: always reports modified=True, even when no clone was present.
        return PassResult(graph_module, True)