# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
import unittest
from typing import Any, Callable, List, Optional, Tuple, Type

import executorch.exir as exir

import torch
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.memory_planning import (
    filter_nodes,
    get_node_tensor_specs,
    greedy,
    naive,
    Verifier,
)
from executorch.exir.pass_base import PassResult
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes import (  # noqa
    MemoryPlanningPass,
    SpecPropPass,
    ToOutVarPass,
)
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from parameterized import parameterized

from torch import nn
from torch.ao.quantization import (  # @manual=//caffe2:torch
    float_qparams_weight_only_qconfig,
)
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)
from torch.ao.quantization.observer import (
    default_dynamic_quant_observer,
    default_per_channel_weight_observer,
)
from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.export import export
from torch.export.exported_program import ExportGraphSignature
from torch.fx import Graph, GraphModule, Node
from torch.nn import functional as F

torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib")


def swap_modules(
    module: torch.nn.Module,
    condition: Callable[[torch.nn.Module], bool],
    convert_func: Callable[[torch.nn.Module], torch.nn.Module],
) -> None:
    reassign = {}
    for name, mod in module.named_children():
        swap_modules(mod, condition, convert_func)
        if condition(mod):
            out = convert_func(mod)
            reassign[name] = out
    for key, value in reassign.items():
        module._modules[key] = value


class ToyModelForMemPlanning(torch.nn.Module):
    def __init__(self) -> None:
        super(ToyModelForMemPlanning, self).__init__()

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        o = a
        for _ in range(10):
            o = o * a
            o = o + b
        return o

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(10), torch.randn(10))


class ModelWithDifferentTensorSizes(torch.nn.Module):
    def __init__(self) -> None:
        super(ModelWithDifferentTensorSizes, self).__init__()
        self.linears = torch.nn.ModuleList()
        for x in [2, 4, 8, 16, 32, 64, 128]:
            self.linears.append(torch.nn.Linear(x, x * 2))

    def forward(self, i: torch.Tensor) -> torch.Tensor:
        o1 = i
        for linear in self.linears:
            o1 = linear(o1)
        o2 = i
        for linear in self.linears:
            o2 = linear(o2)
        return o1 + o2

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(2),)


class ModuleReturnTwo(nn.Module):
    def __init__(self) -> None:
        super(ModuleReturnTwo, self).__init__()
        self.linear1 = nn.Linear(8, 8)
        self.linear2 = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        o1 = self.linear1(x)
        o2 = self.linear2(x)
        return o1, o2

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(8),)


class ModuleListArg(nn.Module):
    r"""
    The module splits a tensor and concatenates the parts again. The cat op
    receives a list of tensors as its argument. We want to make sure we handle
    the lifetime of tensors embedded inside a list arg correctly.
    """

    def __init__(self) -> None:
        super(ModuleListArg, self).__init__()

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        s0, s1 = torch.tensor_split(a, 2)
        s = torch.cat([s0, s1], 0)
        return s

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(8),)

    @staticmethod
    def extra_check(
        testcase: unittest.TestCase, graph_module: torch.fx.GraphModule
    ) -> None:
        """
        Make sure the getitem nodes are still alive when the cat node starts,
        since the cat node takes a list argument containing all the getitem nodes.
        """
        getitem_specs = []
        cat_specs = []
        for node in graph_module.graph.nodes:
            if node.target == torch.ops.aten.cat.out:
                cat_specs.append(node.meta["spec"])
            elif node.target == torch.ops.aten.slice_copy.Tensor_out:
                getitem_specs.append(node.meta["spec"])

        testcase.assertEqual(1, len(cat_specs))
        testcase.assertEqual(2, len(getitem_specs))
        for getitem_spec in getitem_specs:
            testcase.assertTrue(getitem_spec.lifetime[1] >= cat_specs[0].lifetime[0])


class CustomPoolMemoryPlanningPass(MemoryPlanningPass):
    def call(self, graph_module: GraphModule) -> PassResult:
        for subgm in graph_module.modules():
            if not isinstance(subgm, GraphModule):
                continue
            for node in subgm.graph.nodes:
                # mem_id = 1 for placeholders and outputs of mul
                # mem_id = 3 for outputs of add
                # the parent class will copy the spec to the alloc nodes
                if node.op == "placeholder":
                    node.meta["spec"].mem_id = 1
                    continue

                if node.op != "call_function":
                    continue

                if node.target == torch.ops.aten.add.out:
                    node.meta["spec"].mem_id = 3
                elif node.target == torch.ops.aten.mul.out:
                    node.meta["spec"].mem_id = 1

        return super().run(graph_module)

    def run(
        self,
        graph_module: torch.fx.GraphModule,
        graph_signature: Optional[ExportGraphSignature] = None,
    ) -> PassResult:
        return self.call(graph_module)


class MultiplePoolsToyModel(torch.nn.Module):
    def forward(self, a: torch.Tensor) -> torch.Tensor:
        # a: mem_id = 1, offset = 0
        # b: mem_id = 3, offset = 0
        # c: mem_id = 1, offset = 4
        # d: mem_id = 3, offset = 4
        # greedy:
        # e: mem_id = 1, offset = 0
        # naive:
        # e: mem_id = 1, offset = 8
        b = a + a
        c = a * b
        d = c + b
        e = c * d
        return e


def maketest(
    module_cls: Type[torch.nn.Module],
    criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
    extra_check: Optional[Callable[..., None]] = None,
    use_functionalization: bool = True,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
    has_unused_graph_input: bool = False,
) -> Callable[..., None]:
    # parameterized.expand is not compatible with maketest. I'll just loop through
    # the test setups in the wrapper.
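    # The generated test below exports the eager module, lowers it to edge
    # dialect, then runs SpecPropPass -> ToOutVarPass -> MemoryPlanningPass
    # directly (instead of going through to_executorch) so the planned graph
    # module can be inspected. Each criteria entry pairs a planning algorithm
    # with whether we expect it to reuse storage for the module under test.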
    def wrapper(self: "TestMemoryPlanning") -> None:
        nonlocal criteria
        if not criteria:
            criteria = [
                # the naive algorithm does not reuse tensor storages
                (naive, False),
                # the greedy algorithm should reuse tensor storages in the testing model
                (greedy, True),
            ]

        for algo, expect_reuse in criteria:
            print(
                f"algo {getattr(algo, '__name__', repr(algo))}, expect_reuse {expect_reuse}"
            )
            eager_module = module_cls().eval()
            # pyre-fixme[29]: `Union[nn.modules.module.Module,
            #  torch._tensor.Tensor]` is not a function.
            inputs = eager_module.get_random_inputs()
            graph_module = (
                to_edge(
                    export(
                        eager_module,
                        inputs,
                    )
                )
                .exported_program()
                .graph_module
            )

            graph_module = PassManager(
                passes=[
                    SpecPropPass(),
                    ToOutVarPass(),
                    MemoryPlanningPass(
                        algo,
                        alloc_graph_input=alloc_graph_input,
                        alloc_graph_output=alloc_graph_output,
                    ),
                ],
            )(graph_module).graph_module

            self.verify_reuse(
                graph_module, expect_reuse, alloc_graph_input, alloc_graph_output
            )
            self.verify_graph_input_output(
                graph_module, alloc_graph_input, alloc_graph_output
            )

            self.verify_overlap_placeholders(has_unused_graph_input, graph_module)

            # print(f"Final code: {graph_module.code}")
            # print(f"Final graph: {graph_module.graph}")

            if extra_check:
                extra_check(self, graph_module)

    return wrapper


class TestMemoryPlanning(unittest.TestCase):
    def verify_reuse(
        self,
        graph_module: torch.fx.GraphModule,
        expect_reuse: bool,
        alloc_graph_input: bool,
        alloc_graph_output: bool,
    ) -> None:
        r"""
        Do a sanity check and verify tensor storage reuse.

        There should NOT be any storage overlap between tensors that have
        overlapping lifetimes.

        expect_reuse is True if we expect the algorithm to reuse tensor storages
        for at least one pair of tensors in the current testing setup.
        """
        # This call throws if two tensors overlap in both lifetime and storage.
        num_reuse_pairs = Verifier(
            graph_module,
            alloc_graph_input=alloc_graph_input,
            alloc_graph_output=alloc_graph_output,
        ).verify_storage_reuse()

        print(f"num_reuse_pairs is {num_reuse_pairs}")
        if expect_reuse:
            self.assertTrue(num_reuse_pairs > 0)
        else:
            self.assertTrue(num_reuse_pairs == 0)

    def verify_graph_input_output(
        self,
        graph_module: torch.fx.GraphModule,
        alloc_graph_input: bool,
        alloc_graph_output: bool,
    ) -> None:
        Verifier(
            graph_module, alloc_graph_input, alloc_graph_output
        ).verify_graph_input_output()

    def verify_overlap_placeholders(
        self, has_unused_graph_input: bool, graph_module: GraphModule
    ) -> None:
        """
        If every placeholder node is used somewhere, then every pair of
        placeholders should have overlapping lifetimes.
        """
        if has_unused_graph_input:
            return

        ph_list = []
        for nd in graph_module.graph.nodes:
            if nd.op == "placeholder":
                ph_list.append(nd)

        # Since all placeholders are used somewhere, their lifetimes should
        # overlap.
        for i in range(len(ph_list)):
            for j in range(i + 1, len(ph_list)):
                ph_lhs = ph_list[i]
                ph_rhs = ph_list[j]
                self.assertTrue(
                    Verifier.lifetime_overlap(ph_lhs.meta["spec"], ph_rhs.meta["spec"])
                )
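    # The attributes below are test methods generated by maketest(): each is a
    # plain function taking `self`, so assigning it as a class attribute lets
    # unittest discover and run it like an ordinary test_* method.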
    test_basic: Callable[..., None] = maketest(ToyModelForMemPlanning)
    # TODO(zhxchen17) re-enable this.
    # test_while: Callable[..., None] = maketest(
    #     ModuleWhile,
    #     criteria=[
    #         ("naive", False),
    #         ("greedy", False),
    #     ],
    # )
    test_different_tensor_sizes: Callable[..., None] = maketest(
        ModelWithDifferentTensorSizes
    )

    test_return_two: Callable[..., None] = maketest(
        ModuleReturnTwo,
        criteria=[
            (naive, False),
            (greedy, True),
        ],
    )

    # The greedy algorithm will reuse memory if we let it allocate memory for
    # both the graph input and output.
    test_list_arg: Callable[..., None] = maketest(
        ModuleListArg,
        criteria=[
            (naive, False),
            (greedy, True),
        ],
        extra_check=ModuleListArg.extra_check,
    )

    def test_graph_input_output(self) -> None:
        for alloc_graph_input, alloc_graph_output in itertools.product(
            [True, False], [True, False]
        ):
            case = maketest(
                ModelWithDifferentTensorSizes,
                alloc_graph_input=alloc_graph_input,
                alloc_graph_output=alloc_graph_output,
            )
            case(self)


class TestVerifier(unittest.TestCase):
    def test_overlap(self) -> None:
        # first encloses second
        self.assertTrue(Verifier.has_overlap([1, 10], [2, 3]))
        # second encloses first
        self.assertTrue(Verifier.has_overlap([2, 3], [1, 10]))
        # first on the left side
        self.assertTrue(Verifier.has_overlap([1, 4], [2, 5]))
        # first on the right side
        self.assertTrue(Verifier.has_overlap([2, 5], [1, 4]))

        # no overlap; first on the left side
        self.assertFalse(Verifier.has_overlap([1, 2], [5, 6]))
        # no overlap; first on the right side
        self.assertFalse(Verifier.has_overlap([5, 6], [1, 2]))


class TestMisc(unittest.TestCase):
    def test_filter_nodes(self) -> None:
        g = Graph()
        nd_pool = [
            Node(g, f"n{idx}", "placeholder", f"n{idx}", (), {}) for idx in range(10)
        ]
        actual_list = list(
            filter_nodes(
                [
                    nd_pool[0],
                    (nd_pool[1], nd_pool[2]),
                    None,
                    [nd_pool[3]],
                    {"first": nd_pool[4]},
                ]
            )
        )
        expected_list = nd_pool[:5]
        self.assertEqual(len(actual_list), len(expected_list))
        for act, exp in zip(actual_list, expected_list):
            self.assertEqual(id(act), id(exp))

    def quantize(self, eager_model: nn.Module) -> nn.Module:
        quantized_model = eager_model
        linear_qconfig_mapping = QConfigMapping().set_object_type(
            F.linear,
            QConfig(
                activation=default_dynamic_quant_observer,
                weight=default_per_channel_weight_observer,
            ),
        )
        embedding_qconfig_mapping = QConfigMapping().set_object_type(
            F.embedding,
            float_qparams_weight_only_qconfig,
        )
        # quantize the Linear and Embedding submodules
        swap_modules(
            quantized_model,
            lambda mod: isinstance(mod, torch.nn.Linear),
            lambda mod: _convert_to_reference_decomposed_fx(
                prepare_fx(
                    mod,
                    linear_qconfig_mapping,
                    (torch.rand(1, mod.in_features),),
                    backend_config=get_executorch_backend_config(),
                ),
                backend_config=get_executorch_backend_config(),
            ),
        )
        swap_modules(
            quantized_model,
            lambda mod: isinstance(mod, torch.nn.Embedding),
            lambda mod: _convert_to_reference_decomposed_fx(
                prepare_fx(
                    mod,
                    embedding_qconfig_mapping,
                    (torch.ones(1, 1),),
                    backend_config=get_executorch_backend_config(),
                ),
                backend_config=get_executorch_backend_config(),
            ),
        )
        return quantized_model
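    # The expected allocations below follow the comments in MultiplePoolsToyModel,
    # in the order the checked nodes appear: placeholder a, add b, mul c, add d,
    # mul e. Each tensor is a single float (4 bytes), so pool 1 (a, c, e) needs
    # 12 bytes with naive planning but only 8 with greedy, which can put e back
    # at offset 0 once a is no longer live; pool 3 (b, d) needs 8 bytes either
    # way. The expected bufsizes are indexed by mem_id (entries 0 and 2 are
    # unused here).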
    # pyre-ignore
    @parameterized.expand(
        [
            (
                naive,
                [(1, 0), (3, 0), (1, 4), (3, 4), (1, 8)],
                [0, 12, 0, 8],
            ),
            (
                greedy,
                [(1, 0), (3, 0), (1, 4), (3, 4), (1, 0)],
                [0, 8, 0, 8],
            ),
        ]
    )
    def test_multiple_pools(
        self,
        algo: Callable[..., List[int]],
        expected_allocs: List[Tuple[int, int]],
        expected_bufsizes: List[int],
    ) -> None:
        edge_program = to_edge(
            export(
                MultiplePoolsToyModel(),
                (torch.ones(1),),
            )
        )

        edge_program.to_executorch(
            exir.ExecutorchBackendConfig(
                memory_planning_pass=CustomPoolMemoryPlanningPass(
                    memory_planning_algo=algo,
                    alignment=1,
                ),
            )
        )
        graph_module = edge_program.exported_program().graph_module

        verifier = Verifier(
            graph_module,
            alloc_graph_input=True,
            alloc_graph_output=True,
        )
        verifier.verify_storage_reuse()
        verifier.verify_graph_input_output()

        idx = 0
        for node in graph_module.graph.nodes:
            if node.op == "placeholder" or (
                node.op == "call_function"
                and node.target in (torch.ops.aten.add.out, torch.ops.aten.mul.out)
            ):
                mem_id, mem_offset = expected_allocs[idx]
                self.assertEqual(node.meta["spec"].mem_id, mem_id)
                self.assertEqual(node.meta["spec"].mem_offset, mem_offset)
                idx += 1
        self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes)

    def test_constants_not_memory_planned(self) -> None:
        class Simple(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.register_buffer("constant", torch.ones(5, 5))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear(x) + self.constant + 1)

        def count_planned_inputs(
            nodes: List[Node], graph_signature: Any  # pyre-ignore
        ) -> Tuple[int, int]:
            num_mem_planned_placeholders = 0
            num_placeholders = 0
            for node in nodes:
                if node.op == "placeholder":
                    num_placeholders += 1
                    specs = get_node_tensor_specs(node)
                    self.assertGreaterEqual(len(specs), 1)
                    for spec in specs:
                        if spec.mem_id is not None:
                            num_mem_planned_placeholders += 1
            return num_placeholders, num_mem_planned_placeholders

        model = Simple()
        inputs = (torch.randn(5, 5),)

        ep_no_input_planning = to_edge(export(model, inputs)).to_executorch(
            config=ExecutorchBackendConfig(
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )

        num_placeholders, num_planned_placeholders = count_planned_inputs(
            ep_no_input_planning.exported_program().graph_module.graph.nodes,
            ep_no_input_planning.exported_program().graph_signature,
        )
        self.assertEqual(
            num_planned_placeholders,
            0,
        )  # one unplanned user input and 4 constants that shouldn't be planned
        self.assertEqual(
            num_placeholders,
            5,  # x, self.constant, linear weight, linear bias, '1' scalar promoted to tensor
        )
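        # With alloc_graph_input=True the single user input should now be
        # memory planned as well, while the constants (weights, buffer, and the
        # lifted scalar) should still be left unplanned.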
        ep_input_planning = to_edge(export(model, inputs)).to_executorch(
            config=ExecutorchBackendConfig(
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True),
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )

        num_placeholders, num_planned_placeholders = count_planned_inputs(
            ep_input_planning.exported_program().graph_module.graph.nodes,
            ep_input_planning.exported_program().graph_signature,
        )
        self.assertEqual(
            num_planned_placeholders,
            1,
        )  # one planned user input and 4 constants that shouldn't be planned
        self.assertEqual(
            num_placeholders,
            5,
        )