# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools

import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


def _get_split_lengths(split_arg, dim_size: int) -> list[int]:
    """Return the explicit length of every chunk produced by a split.

    Args:
        split_arg: Either a sequence of chunk lengths (``split_with_sizes``)
            or a single chunk size (``split.Tensor``).
        dim_size: Size of the dimension being split.

    Returns:
        The list of chunk lengths, in output order.
    """
    if isinstance(split_arg, (list, tuple)):
        return list(split_arg)

    # Single chunk size: mirror ATen's split semantics. There are
    # max(ceil(dim_size / split_size), 1) chunks, every chunk but the last
    # has `split_size` elements, and the last one holds the remainder.
    # split_size == 0 is only legal for an empty dim and yields one chunk.
    split_size = int(split_arg)
    assert split_size >= 0, "Split size must be non-negative."
    if split_size == 0:
        num_splits = 1
    else:
        num_splits = max(-(-dim_size // split_size), 1)
    last_length = dim_size - split_size * (num_splits - 1)
    return [split_size] * (num_splits - 1) + [last_length]


class ConvertSplitToSlicePass(ExportPass):
    """
    Replace a split operation with many slice operations.

    Handles both ``split_with_sizes_copy`` (explicit list of chunk lengths)
    and ``split_copy.Tensor`` (a single chunk size; the last chunk holds the
    remainder). Every ``getitem`` user of the split is rerouted to a
    ``slice_copy`` of the split's input covering the same range, after which
    the now-unused split/getitem nodes are removed by dead-code elimination.
    """

    split_ops = (
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split_copy.Tensor,
    )
    slice = exir_ops.edge.aten.slice_copy.Tensor

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target not in self.split_ops:
                continue

            # Get useful variables
            split_node = node
            input_node = split_node.all_input_nodes[0]
            output_nodes = split_node.users.copy()
            _, shape, _ = extract_tensor_meta(input_node.meta)
            rank = len(shape)
            # `dim` may be passed positionally or as a keyword; default is 0.
            if len(split_node.args) > 2:
                dim = split_node.args[2]
            else:
                dim = split_node.kwargs.get("dim", 0)
            # Normalize negative dims into [0, rank).
            dim = (dim + rank) % rank
            split_lengths = _get_split_lengths(split_node.args[1], shape[dim])

            assert (
                sum(split_lengths) == shape[dim]
            ), "Given split lengths don't sum up to the size of the dimension."

            # Convert chunk lengths into [start, end) slice bounds: each end is
            # the running total of lengths, each start is the previous end.
            ends = list(itertools.accumulate(split_lengths))
            starts = [0] + ends[:-1]

            # Output nodes are of type getitem
            # Replace them with one slice node for each output node.
            with graph.inserting_before(split_node):
                for output_node in output_nodes:
                    index = output_node.args[1]
                    slice_node = create_node(
                        graph,
                        self.slice,
                        (input_node, dim, starts[index], ends[index]),
                    )
                    slice_node.meta = split_node.meta.copy()
                    # The split's "val" holds one entry per output; keep only
                    # the one this slice replaces.
                    slice_node.meta["val"] = slice_node.meta["val"][index]
                    output_node.replace_all_uses_with(slice_node)
        graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)