# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
import torch.nn as nn
from executorch.exir import memory, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass


def _backend_config(remove_view_copy: bool) -> ExecutorchBackendConfig:
    """Build the ExecutorchBackendConfig used by every test in this file.

    Only ``remove_view_copy`` varies between tests; graph inputs are never
    memory-planned (``alloc_graph_input=False``).
    """
    return ExecutorchBackendConfig(
        remove_view_copy=remove_view_copy,
        memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
    )


class TestModel1(nn.Module):
    """Small model mixing view_copy ops that can and cannot be turned into views.

    The inline comments in ``forward`` describe what the remove_view_copy pass
    is expected to do with each view (as asserted in ``TestRemoveViewCopy``).
    """

    def __init__(self):
        super().__init__()
        self.parameter = nn.Parameter(torch.rand(5, 6))
        self.parameter.requires_grad = False
        self.parameter2 = nn.Parameter(torch.rand(30))
        self.parameter2.requires_grad = False

    def forward(self, x):
        v1 = self.parameter.view(
            6, 5
        )  # removed (becomes memory.view), lifetime of parameter will be extended
        v2 = x.view(6, 5)  # also becomes memory.view (see test_spec, idx 4)
        v3 = torch.ops.aten.mul.Tensor(v1, v2).view(
            30
        )  # removed (becomes memory.view), lifetime of mul.Tensor will be extended
        v4 = torch.ops.aten.mul.Tensor(v3, self.parameter2)
        v5 = v4.view(6, 5)  # not removed, output of the graph
        v6 = v4.view(2, 15)  # not removed, output of the graph
        return v5, v6

    def get_example_inputs(self):
        return (torch.rand(5, 6),)


class TestRemoveViewCopy(unittest.TestCase):
    def test_disable(self) -> None:
        """With remove_view_copy=False, no memory.view nodes may be emitted."""
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)
        etpm = to_edge(ep).to_executorch(
            config=_backend_config(remove_view_copy=False),
        )

        for node in etpm.exported_program().graph_module.graph.nodes:
            self.assertNotEqual(node.target, memory.view)

    def test_output_matches(self) -> None:
        """Removing view_copy ops must not change the program's outputs."""
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)

        epm_remove = to_edge(ep)
        epm_no_remove = copy.deepcopy(
            epm_remove
        )  # to_executorch modifies the edge_program, so we make a copy

        # Run pass with removal
        etpm_remove = epm_remove.to_executorch(
            config=_backend_config(remove_view_copy=True),
        )

        # Run pass without removal (the baseline to compare against)
        etpm_no_remove = epm_no_remove.to_executorch(
            config=_backend_config(remove_view_copy=False),
        )

        out_remove_v5, out_remove_v6 = etpm_remove.exported_program().module()(
            *example_inputs
        )
        out_no_remove_v5, out_no_remove_v6 = etpm_no_remove.exported_program().module()(
            *example_inputs
        )

        self.assertTrue(torch.allclose(out_remove_v5, out_no_remove_v5))
        self.assertTrue(torch.allclose(out_remove_v6, out_no_remove_v6))

    def test_spec(self) -> None:
        """Check TensorSpecs and emitted instructions when view_copy is removed."""
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)

        etpm = to_edge(ep).to_executorch(
            config=_backend_config(remove_view_copy=True),
        )

        # etpm.exported_program().graph.print_tabular()

        # idx opcode         name                      target                              args                                               kwargs
        # ---  -------------  ------------------------  ----------------------------------  -------------------------------------------------  ----------------
        # 0    placeholder    p_parameter               p_parameter                         ()                                                 {}
        # 1    placeholder    p_parameter2              p_parameter2                        ()                                                 {}
        # 2    placeholder    x                         x                                   ()                                                 {}
        # 3    call_function  aten_view_copy_default    <function view at 0x7fe57bea6d40>   (p_parameter, [6, 5])                              {}
        # 4    call_function  aten_view_copy_default_1  <function view at 0x7fe57bea6d40>   (x, [6, 5])                                        {}
        # 5    call_function  alloc                     <function alloc at 0x7fe57bea6c20>  (((6, 5), torch.float32),)                         {}
        # 6    call_function  aten_mul_tensor           aten.mul.out                        (aten_view_copy_default, aten_view_copy_default_1)  {'out': alloc}
        # 7    call_function  aten_view_copy_default_2  <function view at 0x7fe57bea6d40>   (aten_mul_tensor, [30])                            {}
        # 8    call_function  alloc_1                   <function alloc at 0x7fe57bea6c20>  (((30,), torch.float32),)                          {}
        # 9    call_function  aten_mul_tensor_1         aten.mul.out                        (aten_view_copy_default_2, p_parameter2)           {'out': alloc_1}
        # 10   call_function  alloc_2                   <function alloc at 0x7fe57bea6c20>  (((6, 5), torch.float32),)                         {}
        # 11   call_function  aten_view_copy_default_3  aten.view_copy.out                  (aten_mul_tensor_1, [6, 5])                        {'out': alloc_2}
        # 12   output         output_1                  output                              ((aten_view_copy_default_3,),)                     {}
        #
        # NOTE(review): this dump shows a single output, but TestModel1 returns
        # (v5, v6) and the instruction checks below expect two aten::view_copy
        # instructions. The v6 view_copy.out (and its alloc) is missing from the
        # table; presumably it comes after idx 11. Re-capture to confirm.

        nodes = {node.name: node for node in etpm.exported_program().graph.nodes}

        # Fail loudly if node naming changes instead of silently skipping checks.
        for name in (
            "p_parameter",
            "aten_view_copy_default",
            "aten_mul_tensor",
            "aten_view_copy_default_2",
        ):
            self.assertIn(name, nodes)

        # p_parameter's lifetime is extended through aten_view_copy_default
        # (memory.view) to idx 6.
        self.assertEqual(nodes["p_parameter"].meta["spec"].lifetime, [0, 6])

        # aten_view_copy_default is a memory.view of p_parameter. p_parameter is
        # a constant with storage, so the view must share the base's storage.
        const_view = nodes["aten_view_copy_default"]
        const_base = const_view.args[0]
        self.assertEqual(const_base.name, "p_parameter")
        # Both base and view are constants backed by storage, not by a
        # memory-planned buffer.
        for spec in (const_base.meta["spec"], const_view.meta["spec"]):
            self.assertTrue(spec.const)
            self.assertIsNotNone(spec.storage)
            self.assertIsNone(spec.mem_id)
            self.assertIsNone(spec.mem_offset)
        self.assertEqual(
            const_view.meta["spec"].storage, const_base.meta["spec"].storage
        )
        self.assertEqual(
            const_view.meta["spec"].lifetime, const_base.meta["spec"].lifetime
        )

        # aten_mul_tensor's lifetime is extended through aten_view_copy_default_2
        # (memory.view) to idx 9.
        self.assertEqual(nodes["aten_mul_tensor"].meta["spec"].lifetime, [5, 9])

        # aten_view_copy_default_2 is a memory.view of aten_mul_tensor.
        planned_view = nodes["aten_view_copy_default_2"]
        planned_base = planned_view.args[0]
        self.assertEqual(planned_base.name, "aten_mul_tensor")
        # Base and view are not constants and have no storage, but they do have
        # a memory-planned location (mem_id, mem_offset).
        for spec in (planned_base.meta["spec"], planned_view.meta["spec"]):
            self.assertFalse(spec.const)
            self.assertIsNone(spec.storage)
            self.assertIsNotNone(spec.mem_id)
            self.assertIsNotNone(spec.mem_offset)
        # The view aliases its base: same buffer location and same lifetime.
        self.assertEqual(
            planned_view.meta["spec"].mem_id, planned_base.meta["spec"].mem_id
        )
        self.assertEqual(
            planned_view.meta["spec"].mem_offset,
            planned_base.meta["spec"].mem_offset,
        )
        self.assertEqual(
            planned_view.meta["spec"].lifetime, planned_base.meta["spec"].lifetime
        )

        # Test evalues in execution plan
        plan = etpm.executorch_program.execution_plan[0]
        self.assertEqual(plan.operators[0].name, "executorch_prim::et_view")
        self.assertEqual(plan.operators[1].name, "aten::mul")
        self.assertEqual(plan.operators[2].name, "aten::view_copy")

        instructions = plan.chains[0].instructions
        self.assertEqual(len(instructions), 7)

        # op_index refers to plan.operators above; idx values refer to the table.
        expected_op_indices = [
            0,  # et_view of p_parameter @ idx3
            0,  # et_view of x @ idx4
            1,  # aten::mul @ idx6
            0,  # et_view of aten_mul_tensor @ idx7
            1,  # aten::mul @ idx9
            2,  # aten::view_copy (v5) @ idx11
            2,  # aten::view_copy (v6), not shown in the table above
        ]
        self.assertEqual(
            [instr.instr_args.op_index for instr in instructions],  # pyre-ignore
            expected_op_indices,
        )