# xref: /aosp_15_r20/external/executorch/exir/tests/test_remove_view_copy.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
import torch.nn as nn
from executorch.exir import memory, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass

16
class TestModel1(nn.Module):
    """Toy model whose graph mixes removable and non-removable view_copy ops.

    The forward pass is written so that, after export, some view_copy nodes
    can be turned into zero-copy memory.view ops (views of constants or of
    intermediates) while others must remain real copies (views of graph
    inputs and graph outputs).
    """

    def __init__(self):
        super().__init__()
        # Frozen (non-trainable) parameters so export treats them as constants.
        self.parameter = nn.Parameter(torch.rand(5, 6), requires_grad=False)
        self.parameter2 = nn.Parameter(torch.rand(30), requires_grad=False)

    def forward(self, x):
        # View of a constant: removable; the parameter's lifetime is extended.
        const_view = self.parameter.view(6, 5)
        # View of a graph input: not removable.
        input_view = x.view(6, 5)
        # View of an intermediate: removable; the mul result's lifetime is
        # extended through it.
        flat = torch.ops.aten.mul.Tensor(const_view, input_view).view(30)
        product = torch.ops.aten.mul.Tensor(flat, self.parameter2)
        # Views feeding the graph outputs: not removable.
        out_a = product.view(6, 5)
        out_b = product.view(2, 15)
        return out_a, out_b

    def get_example_inputs(self):
        return (torch.rand(5, 6),)
40
41
class TestRemoveViewCopy(unittest.TestCase):
    """Tests for view_copy removal during the to_executorch lowering."""

    def test_disable(self) -> None:
        """With remove_view_copy=False, no memory.view nodes appear in the graph."""
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)
        etpm = to_edge(ep).to_executorch(
            config=ExecutorchBackendConfig(
                remove_view_copy=False,
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
            ),
        )

        # When the pass is disabled, no view_copy node may have been converted
        # to the internal memory.view op.
        for node in etpm.exported_program().graph_module.graph.nodes:
            assert node.target != memory.view

    def test_output_matches(self) -> None:
        """Outputs are numerically identical with and without view_copy removal."""
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)

        epm_remove = to_edge(ep)
        epm_no_remove = copy.deepcopy(
            epm_remove
        )  # to_executorch modifies the edge_program, so we make a copy

        # Lower with view_copy removal enabled.
        etpm_remove = epm_remove.to_executorch(
            config=ExecutorchBackendConfig(
                remove_view_copy=True,
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
            ),
        )

        # Lower with view_copy removal disabled.
        # NOTE(review): the original passed remove_view_copy=True here as well,
        # which compared a program against an identical copy of itself; use
        # False so the two lowerings actually differ and the comparison is
        # meaningful.
        etpm_no_remove = epm_no_remove.to_executorch(
            config=ExecutorchBackendConfig(
                remove_view_copy=False,
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
            ),
        )

        out_remove_v5, out_remove_v6 = etpm_remove.exported_program().module()(
            *example_inputs
        )
        out_no_remove_v5, out_no_remove_v6 = etpm_no_remove.exported_program().module()(
            *example_inputs
        )

        self.assertTrue(torch.allclose(out_remove_v5, out_no_remove_v5))
        self.assertTrue(torch.allclose(out_remove_v6, out_no_remove_v6))

    def test_spec(self) -> None:
        """Check TensorSpec metadata (lifetime, storage, mem_id/offset) after removal."""
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)

        etpm = to_edge(ep).to_executorch(
            config=ExecutorchBackendConfig(
                remove_view_copy=True,
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
            ),
        )

        # Uncomment to inspect the lowered graph:
        # etpm.exported_program().graph.print_tabular()

        # idx  opcode         name                      target                              args                                                kwargs
        # ---  -------------  ------------------------  ----------------------------------  --------------------------------------------------  ----------------
        # 0    placeholder    p_parameter               p_parameter                         ()                                                  {}
        # 1    placeholder    p_parameter2              p_parameter2                        ()                                                  {}
        # 2    placeholder    x                         x                                   ()                                                  {}
        # 3    call_function  aten_view_copy_default    <function view at 0x7fe57bea6d40>   (p_parameter, [6, 5])                               {}
        # 4    call_function  aten_view_copy_default_1  <function view at 0x7fe57bea6d40>   (x, [6, 5])                                         {}
        # 5    call_function  alloc                     <function alloc at 0x7fe57bea6c20>  (((6, 5), torch.float32),)                          {}
        # 6    call_function  aten_mul_tensor           aten.mul.out                        (aten_view_copy_default, aten_view_copy_default_1)  {'out': alloc}
        # 7    call_function  aten_view_copy_default_2  <function view at 0x7fe57bea6d40>   (aten_mul_tensor, [30])                             {}
        # 8    call_function  alloc_1                   <function alloc at 0x7fe57bea6c20>  (((30,), torch.float32),)                           {}
        # 9    call_function  aten_mul_tensor_1         aten.mul.out                        (aten_view_copy_default_2, p_parameter2)            {'out': alloc_1}
        # 10   call_function  alloc_2                   <function alloc at 0x7fe57bea6c20>  (((6, 5), torch.float32),)                          {}
        # 11   call_function  aten_view_copy_default_3  aten.view_copy.out                  (aten_mul_tensor_1, [6, 5])                         {'out': alloc_2}
        # 12   output         output_1                  output                              ((aten_view_copy_default_3,),)                      {}

        for node in etpm.exported_program().graph.nodes:
            if node.name == "p_parameter":
                # p_parameter's lifetime is extended through aten_view_copy_default (memory.view) to idx 6
                self.assertEqual(node.meta["spec"].lifetime, [0, 6])
            elif node.name == "aten_view_copy_default":
                # aten_view_copy_default is a memory.view of p_parameter.
                # p_parameter is a constant with storage, so we check that the view's storage matches the base

                # assert base is p_parameter
                self.assertEqual(node.args[0].name, "p_parameter")

                # assert base is const with storage, and no planned memory
                self.assertTrue(node.args[0].meta["spec"].const)
                self.assertTrue(node.args[0].meta["spec"].storage is not None)
                self.assertTrue(node.args[0].meta["spec"].mem_id is None)
                self.assertTrue(node.args[0].meta["spec"].mem_offset is None)

                # assert self is const with storage, and no planned memory
                self.assertTrue(node.meta["spec"].const)
                self.assertTrue(node.meta["spec"].storage is not None)
                self.assertTrue(node.meta["spec"].mem_id is None)
                self.assertTrue(node.meta["spec"].mem_offset is None)

                # assert the view shares its base's storage
                self.assertEqual(
                    node.meta["spec"].storage, node.args[0].meta["spec"].storage
                )

                # assert lifetime matches
                self.assertEqual(
                    node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime
                )
            elif node.name == "aten_mul_tensor":
                # aten_mul_tensor's lifetime is extended through aten_view_copy_default_2 (memory.view) to idx 9
                self.assertEqual(node.meta["spec"].lifetime, [5, 9])
            elif node.name == "aten_view_copy_default_2":
                # aten_view_copy_default_2 is a memory.view of aten_mul_tensor

                # assert base is aten_mul_tensor
                self.assertEqual(node.args[0].name, "aten_mul_tensor")

                # assert base and self are not const, do not have storage,
                # but do have mem_id and mem_offset (planned memory)
                self.assertFalse(node.args[0].meta["spec"].const)
                self.assertTrue(node.args[0].meta["spec"].storage is None)
                self.assertTrue(node.args[0].meta["spec"].mem_id is not None)
                self.assertTrue(node.args[0].meta["spec"].mem_offset is not None)

                self.assertFalse(node.meta["spec"].const)
                self.assertTrue(node.meta["spec"].storage is None)
                self.assertTrue(node.meta["spec"].mem_id is not None)
                self.assertTrue(node.meta["spec"].mem_offset is not None)

                # assert self and base mem_id, mem_offset, and lifetime match
                self.assertEqual(
                    node.meta["spec"].mem_id, node.args[0].meta["spec"].mem_id
                )
                self.assertEqual(
                    node.meta["spec"].mem_offset, node.args[0].meta["spec"].mem_offset
                )
                self.assertEqual(
                    node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime
                )

        # Test operators and instructions in the serialized execution plan
        plan = etpm.executorch_program.execution_plan[0]
        self.assertEqual(plan.operators[0].name, "executorch_prim::et_view")
        self.assertEqual(plan.operators[1].name, "aten::mul")
        self.assertEqual(plan.operators[2].name, "aten::view_copy")

        instructions = plan.chains[0].instructions
        self.assertEqual(len(instructions), 7)

        self.assertEqual(
            instructions[0].instr_args.op_index, 0  # pyre-ignore
        )  # view @ idx2
        self.assertEqual(
            instructions[1].instr_args.op_index, 0  # pyre-ignore
        )  # view @ idx3
        self.assertEqual(
            instructions[2].instr_args.op_index, 1  # pyre-ignore
        )  # aten:mul @ idx6
        self.assertEqual(
            instructions[3].instr_args.op_index, 0  # pyre-ignore
        )  # view @ idx7
        self.assertEqual(
            instructions[4].instr_args.op_index, 1  # pyre-ignore
        )  # aten:mul @ idx9
        self.assertEqual(
            instructions[5].instr_args.op_index, 2  # pyre-ignore
        )  # aten:view_copy @ idx11
        self.assertEqual(
            instructions[6].instr_args.op_index, 2  # pyre-ignore
        )  # aten:view_copy @ idx11
220