# xref: /aosp_15_r20/external/executorch/backends/arm/_passes/convert_split_to_slice.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

9import torch.fx
10from executorch.backends.arm._passes.arm_pass_utils import create_node
11from executorch.backends.arm.tosa_mapping import extract_tensor_meta
12from executorch.exir.dialects._ops import ops as exir_ops
13from executorch.exir.pass_base import ExportPass, PassResult
14
15
class ConvertSplitToSlicePass(ExportPass):
    """
    Replace a split operation with many slice operations.

    Both ``split_with_sizes_copy`` (an explicit list of chunk lengths) and
    ``split_copy.Tensor`` (a single int chunk size) are handled. Every user of
    the split node (expected to be a ``getitem`` selecting one chunk) is
    replaced by a ``slice_copy`` of the split's input, and the now-unused
    split/getitem nodes are then removed by dead-code elimination.
    """

    split_ops = (
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split_copy.Tensor,
    )
    slice = exir_ops.edge.aten.slice_copy.Tensor

    @staticmethod
    def _split_lengths(split_arg, dim_size):
        """Return the chunk lengths described by a split operator argument.

        Args:
            split_arg: Either a sequence of chunk lengths (split_with_sizes)
                or a single int chunk size (split.Tensor).
            dim_size: Size of the dimension being split.

        Returns:
            A list of chunk lengths. For an int chunk size this mirrors
            ``torch.split``: equally sized chunks, a possibly smaller final
            chunk, and always at least one chunk.
        """
        if not isinstance(split_arg, int):
            return list(split_arg)
        split_size = split_arg
        if split_size == 0:
            # torch.split only accepts a zero chunk size for an empty
            # dimension, which produces a single empty chunk.
            return [dim_size]
        # Ceil-divide, but never produce fewer than one chunk (empty dim).
        num_chunks = max(-(-dim_size // split_size), 1)
        last = dim_size - split_size * (num_chunks - 1)
        return [split_size] * (num_chunks - 1) + [last]

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite every split node in ``graph_module`` into slice nodes.

        Args:
            graph_module: The module whose graph is rewritten in place.

        Returns:
            ``PassResult`` wrapping the module produced by ``ExportPass.call``
            on the rewritten graph; the modified flag is always ``True``.

        Raises:
            AssertionError: If the split lengths do not sum to the size of
                the split dimension.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target not in self.split_ops:
                continue

            split_node = node
            input_node = split_node.all_input_nodes[0]
            # Snapshot the users before the graph is edited below.
            output_nodes = split_node.users.copy()
            _, shape, _ = extract_tensor_meta(input_node.meta)
            rank = len(shape)
            if len(split_node.args) > 2:
                dim = split_node.args[2]
            else:
                dim = split_node.kwargs.get("dim", 0)
            # Normalize a negative dim to its non-negative equivalent.
            dim = (dim + rank) % rank
            dim_size = shape[dim]

            # split_copy.Tensor carries an int chunk size rather than a list
            # of lengths; normalize both forms to a list of lengths.
            split_lengths = self._split_lengths(split_node.args[1], dim_size)

            assert (
                sum(split_lengths) == dim_size
            ), "Given split lengths don't sum up to the size of the dimension."

            # Convert chunk lengths into [start, end) bounds along `dim`.
            starts = []
            ends = []
            offset = 0
            for split_length in split_lengths:
                starts.append(offset)
                offset += split_length
                ends.append(offset)

            # Output nodes are of type getitem: args[1] is the chunk index.
            # Replace each of them with one slice node producing that chunk.
            with graph.inserting_before(split_node):
                for output_node in output_nodes:
                    index = output_node.args[1]
                    slice_node = create_node(
                        graph,
                        self.slice,
                        (input_node, dim, starts[index], ends[index]),
                    )
                    slice_node.meta = split_node.meta.copy()
                    # The split's "val" holds one entry per chunk; keep only
                    # the entry for the chunk this slice produces.
                    slice_node.meta["val"] = slice_node.meta["val"][index]
                    output_node.replace_all_uses_with(slice_node)
        graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
75