# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
import logging
import operator
import typing
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from executorch.exir import memory
from executorch.exir.control_flow import while_loop as exir_while
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.error import internal_assert, InternalError
from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.tensor import TensorSpec

from torch import fx
from torch.export.exported_program import ExportGraphSignature
from torch.fx import Node
from torch.utils._pytree import tree_flatten

REGISTERED_ALGOS: Dict[str, Callable[..., List[int]]] = {}


class Verifier:
    """
    Verify that the outcome of a memory planning algorithm makes sense.
    E.g., make sure tensors with overlapping lifetimes do not have overlapping
    storage/buffers.
    """

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        alloc_graph_input: bool,
        alloc_graph_output: bool,
        graph_signature: Optional[ExportGraphSignature] = None,
    ) -> None:
        self.graph_module = graph_module
        self.graph_signature = graph_signature
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output

    @classmethod
    def mem_obj_id_match(
        cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec, accept_both_none: bool = True
    ) -> bool:
        """
        Given two `TensorSpec`, return whether their `mem_obj_id` are the same.
        Note that if both are None, this function returns True if
        `accept_both_none` is True and False otherwise.
        """
        if lhs_spec.mem_id != rhs_spec.mem_id:
            return False

        # both are None
        if lhs_spec.mem_obj_id is None and rhs_spec.mem_obj_id is None:
            return accept_both_none

        return lhs_spec.mem_obj_id == rhs_spec.mem_obj_id

    @classmethod
    def has_overlap(cls, lhs_ivl: List[int], rhs_ivl: List[int]) -> bool:
        r"""
        The passed-in intervals are inclusive on both sides. Return whether
        they overlap.
        """
        # empty interval
        if lhs_ivl[0] > lhs_ivl[1] or rhs_ivl[0] > rhs_ivl[1]:
            return False

        return (lhs_ivl[0] >= rhs_ivl[0] and lhs_ivl[0] <= rhs_ivl[1]) or (
            rhs_ivl[0] >= lhs_ivl[0] and rhs_ivl[0] <= lhs_ivl[1]
        )
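
    # Illustrative note only: has_overlap([0, 3], [3, 5]) is True (both
    # intervals contain index 3), has_overlap([0, 3], [4, 5]) is False, and an
    # "empty" interval such as [5, 4] never overlaps anything.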

    @classmethod
    def lifetime_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
        lhs_lifetime = lhs_spec.lifetime
        rhs_lifetime = rhs_spec.lifetime
        internal_assert(
            lhs_lifetime[0] is not None and lhs_lifetime[1] is not None,
            f"{lhs_spec} should have valid start and end",
        )
        internal_assert(
            rhs_lifetime[0] is not None and rhs_lifetime[1] is not None,
            f"{rhs_spec} should have valid start and end",
        )
        return cls.has_overlap(lhs_lifetime, rhs_lifetime)

    @classmethod
    def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
        intervals = []
        if lhs_spec.mem_id != rhs_spec.mem_id:
            return False
        for spec in [lhs_spec, rhs_spec]:
            internal_assert(
                spec.allocated_memory >= 0,
                f"{spec} should have non-negative allocated memory",
            )
            internal_assert(
                isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
                f"{spec} should have a specified memory offset",
            )
            intervals.append(
                [spec.mem_offset, spec.mem_offset + spec.allocated_memory - 1]
            )
        has_overlap = cls.has_overlap(*intervals)

        return has_overlap

    def verify_storage_reuse(
        self, allow_lifetime_and_storage_overlap: bool = False
    ) -> int:
        """
        'allow_lifetime_and_storage_overlap' allows tensors to overlap in both
        lifetime and storage. If it is False and two tensors overlap in both
        lifetime and storage, throw an exception.
        Returns:
            Number of pairs of tensors that have overlapping storage.
        """
        num_reuse_pairs = 0

        # unique tensor specs
        all_specs = list(
            collect_specs_from_nodes(
                self.graph_module.graph.nodes,
                self.graph_signature,
                ignore_const=True,
                ignore_graph_input=not self.alloc_graph_input,
                ignore_graph_output=not self.alloc_graph_output,
                do_assertion=False,
                ignore_out_var_node=False,
                dedup=True,
            )
        )

        for lhs_spec_idx, lhs_spec in enumerate(all_specs):
            for rhs_spec in all_specs[lhs_spec_idx + 1 :]:
                # Check that both specs are consistent about whether mem_obj_id is defined
                if (lhs_spec.mem_obj_id is None) != (rhs_spec.mem_obj_id is None):
                    raise InternalError(
                        "Specs do not agree on whether mem_obj_id is defined."
                    )

                has_storage_overlap = Verifier.storage_overlap(lhs_spec, rhs_spec)
                if not has_storage_overlap:
                    continue

                if not allow_lifetime_and_storage_overlap and self.lifetime_overlap(
                    lhs_spec, rhs_spec
                ):
                    raise InternalError(
                        f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}"
                    )

                # Check that each mem_obj_id is consistent with whether the tensors
                # have storage overlap
                if not Verifier.mem_obj_id_match(lhs_spec, rhs_spec):
                    raise InternalError(
                        f"Unexpected mem_obj_id mismatch: lhs {lhs_spec}, rhs {rhs_spec}"
                    )

                num_reuse_pairs += 1

        return num_reuse_pairs

    def verify_graph_input_output(self) -> None:
        r"""
        alloc_graph_input / alloc_graph_output indicate whether memory for the
        graph input/output is allocated by the compiler. If not, the runtime
        will set them using buffers provided by users.
        """
        graph_module = self.graph_module
        # There is one tricky case here. If the graph input and graph output
        # tensors overlap, but alloc_graph_input != alloc_graph_output,
        # then the overlapping tensors will cause an assertion failure below.
        # The current behavior is: if either alloc_graph_input or
        # alloc_graph_output is False, those overlapping tensors will not have
        # memory allocated.
        #
        # Ignore the check in this case for now.
        overlap = get_graph_input_tensors(
            graph_module.graph.nodes, self.graph_signature
        ) & get_graph_output_tensors(graph_module.graph.nodes)
        if overlap and (self.alloc_graph_input != self.alloc_graph_output):
            logging.debug(
                "Skipping the check: graph input/output tensors overlap while the allocation decisions for graph input and output differ."
            )
            return

        graph_input_allocated = None
        graph_output_allocated = None

        has_dynamic_unbound_input = False
        has_dynamic_unbound_output = False

        check_list = {"placeholder", "output"} & {
            node.op for node in graph_module.graph.nodes
        }
        assert "output" in check_list, f"graph module has no output: {graph_module}"

        for nd in graph_module.graph.nodes:
            if nd.op in check_list:
                if not (specs := get_node_tensor_specs(nd)):
                    continue
                if _is_mutable_buffer(nd, self.graph_signature):
                    continue
                assert len(specs) > 0, "Expect tensor specs"
                specs = list(filter(lambda spec: not spec.const, specs))
                if len(specs) == 0:
                    continue
                allocated = any(
                    spec is None or spec.mem_offset is not None for spec in specs
                )
                has_dynamic_unbound_tensor = any(
                    spec is None
                    or spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
                    for spec in specs
                )
                assert (
                    all(spec is None or spec.mem_offset is not None for spec in specs)
                    == allocated
                ), "Either all or none of the tensors should be allocated memory"
                if nd.op == "placeholder":
                    graph_input_allocated = allocated
                    has_dynamic_unbound_input |= has_dynamic_unbound_tensor
                else:
                    graph_output_allocated = allocated
                    has_dynamic_unbound_output |= has_dynamic_unbound_tensor

        if "placeholder" in check_list:
            assert graph_input_allocated is not None, "graph_input_allocated not set"
            if not has_dynamic_unbound_input:
                assert (
                    graph_input_allocated == self.alloc_graph_input
                ), f"Misallocate graph input: {graph_input_allocated} v.s. {self.alloc_graph_input}"

        assert graph_output_allocated is not None, "graph_output_allocated not set"
        if not has_dynamic_unbound_output:
            assert (
                graph_output_allocated == self.alloc_graph_output
            ), f"Misallocate graph output: {graph_output_allocated} v.s. {self.alloc_graph_output}"
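

# Usage sketch (illustrative only; not invoked by this module). After a
# planning algorithm has filled in mem_id / mem_offset on every TensorSpec,
# the result can be sanity-checked roughly like this, where `planned_gm` and
# `signature` are hypothetical placeholders for an already planned
# torch.fx.GraphModule and its ExportGraphSignature:
#
#     verifier = Verifier(
#         planned_gm,
#         alloc_graph_input=True,
#         alloc_graph_output=True,
#         graph_signature=signature,
#     )
#     verifier.verify_graph_input_output()
#     num_reuse_pairs = verifier.verify_storage_reuse()
#
# Both calls raise if the plan is inconsistent.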


def _is_out_var_node(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and is_out_variant(node.target._schema.name, node.target._schema.overload_name)
    )


def _is_inplace_node(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and is_inplace_variant(
            node.target._schema.name, node.target._schema.overload_name
        )
    )


def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
    r"""
    Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
    is represented by the indices of the first and last nodes referring to
    that tensor in their inputs/outputs.

    Arguments:
        spec: the TensorSpec for the tensor
        node_idx: extend the tensor's lifetime to cover node_idx
    """
    start, end = spec.lifetime
    start = node_idx if start is None or start > node_idx else start
    end = node_idx if end is None or end < node_idx else end
    spec.lifetime = [start, end]
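
# Minimal sketch of the widening behavior (illustrative; the indices are
# arbitrary). Starting from a fresh spec whose lifetime is [None, None]:
#
#     update_tensor_lifetime(spec, 5)   # lifetime becomes [5, 5]
#     update_tensor_lifetime(spec, 2)   # start widens:    [2, 5]
#     update_tensor_lifetime(spec, 9)   # end widens:      [2, 9]
#     update_tensor_lifetime(spec, 4)   # no-op: 4 already falls inside [2, 9]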


# pyre-ignore
def filter_nodes(inputs: Iterable[Any]) -> Iterable[Node]:
    """
    This method needs to return Node objects embedded inside List/Dict as well.
    """
    return [nd for nd in tree_flatten(list(inputs))[0] if isinstance(nd, Node)]


def _is_mutable_buffer(
    node: Node, graph_signature: Optional[ExportGraphSignature] = None
) -> bool:
    """
    Check if the node is a mutable buffer according to the provided graph signature.
    """
    # The graph signature is None for memory planning passes not called from
    # EdgeProgramManager. Those paths are deprecated, so mutable buffers are not
    # supported on them.
    if graph_signature is None:
        return False
    if node.op == "placeholder":
        if isinstance(node.target, str):
            if node.target in graph_signature.inputs_to_buffers:
                fqn = graph_signature.inputs_to_buffers[node.target]
                # if the buffer is mutated then record that
                if fqn in graph_signature.buffers_to_mutate.values():
                    return True
    return False


def get_graph_input_tensors(
    nodes: Iterable[Node], graph_signature: Optional[ExportGraphSignature] = None
) -> Set[TensorSpec]:
    graph_input_tensors = set()
    for node in nodes:
        if node.op == "placeholder" and not _is_mutable_buffer(node, graph_signature):
            for spec in get_node_tensor_specs(node):
                graph_input_tensors.add(spec)

    return graph_input_tensors


def get_graph_output_tensors(nodes: Iterable[Node]) -> Set[TensorSpec]:
    graph_output_tensors = set()
    for node in nodes:
        if node.op == "output":
            for spec in get_node_tensor_specs(node):
                graph_output_tensors.add(spec)

    return graph_output_tensors


def collect_specs_from_nodes(  # noqa: C901
    nodes: Iterable[Node],
    graph_signature: Optional[ExportGraphSignature] = None,
    ignore_graph_input: bool = False,
    ignore_graph_output: bool = False,
    ignore_const: bool = True,
    ignore_out_var_node: bool = True,
    dedup: bool = True,
    do_assertion: bool = True,
    ignore_dynamic_unbound_tensor: bool = True,
) -> Iterable[TensorSpec]:
    r"""
    Collect specs from the passed-in nodes. Do filtering as controlled by the
    arguments.
    Arguments:
        ignore_graph_input: ignore graph input tensors from placeholder nodes
        ignore_const: whether to ignore const tensors
        ignore_out_var_node: whether to ignore out-variant nodes
        dedup: whether to dedup the specs
        do_assertion: whether to assert that the filtered nodes belong to a
            restricted set like alloc, getitem
    """
    unique_spec = set()
    graph_input_tensors: Set[TensorSpec] = (
        get_graph_input_tensors(nodes, graph_signature) if ignore_graph_input else set()
    )
    graph_output_tensors: Set[TensorSpec] = (
        get_graph_output_tensors(nodes) if ignore_graph_output else set()
    )

    for node in nodes:
        # ignore the specs from irrelevant FX ops for now.
        if node.op in ["get_attr"]:
            continue

        # don't reallocate memory for an out-variant op's output tensors,
        # since they are just the input tensors.
        if ignore_out_var_node and _is_out_var_node(node):
            continue

        if not (specs := get_node_tensor_specs(node)):
            continue

        if _is_inplace_node(node):
            continue

        if do_assertion:
            internal_assert(
                node.op in ("placeholder", "output")
                or node.target
                in [
                    memory.alloc,
                    memory.view,
                    operator.getitem,
                    torch.ops.higher_order.cond,
                    exir_while,
                    torch.ops.higher_order.map_impl,
                    executorch_call_delegate,
                ],
                f"Unexpected op {node.op}, target {node.target}",
            )
        for spec in specs:
            if spec is None:
                continue
            # Dynamic unbound tensors' memory will be allocated by the runtime.
            # Memory planning should ignore them.
            if (
                ignore_dynamic_unbound_tensor
                and spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
            ):
                continue

            # Note: a graph input may also be the output of other ops (e.g. the
            # return op). If ignore_graph_input is True, we should ignore those
            # tensors so we skip planning memory for graph inputs.
            if ignore_graph_input and spec in graph_input_tensors:
                continue
            if ignore_graph_output and spec in graph_output_tensors:
                continue
            if (
                ignore_const
                and spec.const
                and not node.meta.get("weight_has_gradient", False)
            ):
                continue
            if dedup:
                if spec in unique_spec:
                    continue
                else:
                    unique_spec.add(spec)
            yield spec


def update_all_tensors_lifetime(
    graph_module: torch.fx.GraphModule,
    graph_signature: Optional[ExportGraphSignature] = None,
) -> Set[TensorSpec]:
    r"""
    Set the lifetime for all the tensors encountered in the FX graph.
    """
    specs = set()
    for node_idx, node in enumerate(graph_module.graph.nodes):
        for spec in collect_specs_from_nodes(
            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
            graph_signature,
            ignore_graph_input=False,
            ignore_const=False,
            ignore_out_var_node=False,
            dedup=False,
            do_assertion=False,
            ignore_dynamic_unbound_tensor=False,
        ):
            update_tensor_lifetime(spec, node_idx)
            specs.add(spec)
    return specs


@dataclass
class SharedObject:
    r"""
    We define the concept of a shared object, which represents a segment
    in the memory buffer that can be shared by multiple tensors. In order to
    check if a shared object is available for a tensor, we maintain the
    last_used_index attribute. The shared object will be available for nodes
    with index greater than last_used_index.
    """

    # index of the shared object in the list of shared objects, used as a unique id
    idx: int
    # offset in the memory buffer
    offset: int
    # size of this shared object in bytes
    size: int
    # the object will be available for index (last_used_index + 1)
    last_used_index: int


def materialize_buffer(
    shared_objects: List[SharedObject], input_total_size: int = 0
) -> int:
    r"""
    Assign a concrete location in the buffer to each SharedObject.offset.

    Assumes all the passed-in shared objects belong to the same memory buffer.
    """
    total_size = input_total_size
    for sobj in shared_objects:
        sobj.offset = total_size
        total_size += sobj.size
    return total_size
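
# Minimal sketch of the two-phase layout used by the greedy algorithm
# (illustrative; the sizes are arbitrary). Shared objects are first created
# with offset -1, then materialize_buffer() packs them back to back:
#
#     sobjs = [
#         SharedObject(idx=0, offset=-1, size=64, last_used_index=3),
#         SharedObject(idx=1, offset=-1, size=32, last_used_index=5),
#     ]
#     total = materialize_buffer(sobjs)
#     # sobjs[0].offset == 0, sobjs[1].offset == 64, total == 96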


def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int:
    r"""
    Calculate the absolute difference between the size of a shared object and
    a tensor.
    """
    return abs(sobj.size - spec.allocated_memory)


def pick_shared_obj(
    shared_objects: List[SharedObject], spec: TensorSpec
) -> SharedObject:
    r"""
    Pick the available shared object whose size is closest to the tensor's.
    If there is no available shared object left, create a new one.
    """
    # TODO: do better than linear scan
    picked = None
    for sobj in shared_objects:
        if spec.lifetime[0] > sobj.last_used_index:
            if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif(
                picked, spec
            ):
                picked = sobj
                sobj.last_used_index = spec.lifetime[1]
                sobj.size = max(sobj.size, spec.allocated_memory)
    if picked is None:
        picked = SharedObject(
            len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
        )
        shared_objects.append(picked)

    return picked


def get_node_tensor_specs(
    node: torch.fx.Node,
) -> Union[List[TensorSpec], Tuple[TensorSpec]]:
    r"""
    Return the list of tensor specs for the node, or an empty list if the node
    has no tensor specs.
    """
    # get tensor specs
    if node.target == memory.view:
        base = node.args[0]
        assert isinstance(base, torch.fx.Node)
        specs = base.meta.get("spec")
    else:
        specs = node.meta.get("spec")

    if isinstance(specs, TensorSpec):
        specs = [specs]
    if not isinstance(specs, (list, tuple)):
        return []
    else:
        return [
            spec
            for spec in specs
            if not isinstance(spec, (int, float, bool, str, type(None)))
        ]


def greedy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:
    spec2obj = {}
    shared_objects = defaultdict(list)
    # Don't do assertion in collect_specs_from_nodes if we have already encountered
    # and ignored some to_out_variant errors.
    do_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False)
    # For each tensor, pick the available shared object with the closest size to
    # the tensor. If there is no available shared object left, create a new one.
    for spec in collect_specs_from_nodes(
        graph_module.graph.nodes,
        graph_signature,
        do_assertion=do_assertion,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    ):
        if spec.mem_id is None:
            spec.mem_id = 1
        spec.realign(alignment)
        spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec)

    if len(shared_objects) == 0:
        # Cannot find any tensor in the graph that needs to be allocated.
        # Return [0, 0] to be consistent with the default behavior of naive.
        total_sizes = [0, 0]
    else:
        total_sizes = [0] * (max(shared_objects.keys()) + 1)
        for mem_id in shared_objects:
            input_total_size = 0
            if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
                # pyre-fixme[6]: For 1st argument expected
                #  `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
                if len(bufsizes) > mem_id:
                    # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
                    input_total_size = bufsizes[mem_id]
            total_sizes[mem_id] = materialize_buffer(
                shared_objects[mem_id], input_total_size
            )

    # Since we now know the number of shared objects we need and the size of
    # each shared object, we can assign an offset in the memory buffer to each
    # shared object.
    for spec, sobj in spec2obj.items():
        spec.mem_obj_id = sobj.idx
        spec.mem_offset = sobj.offset

    logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
    return total_sizes
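
# Illustrative contrast (arbitrary example, not derived from a real graph):
# given three 64-byte tensors with pairwise disjoint lifetimes and mem_id 1,
# greedy() lets all three reuse a single 64-byte SharedObject and returns
# bufsizes of roughly [0, 64] (modulo alignment), whereas naive() below gives
# each tensor its own slice and returns [0, 192].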


def naive(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:

    # Allocate `allocated` bytes from the buffer with id `mem_id`.
    # Return the starting offset of the allocated region.
    def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
        if mem_id >= len(bufsizes):
            bufsizes.extend([0] * (mem_id - len(bufsizes) + 1))
        ret = bufsizes[mem_id]
        bufsizes[mem_id] += allocated
        return ret

    bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
    if bufsizes is None:
        bufsizes = [0, 0]

    bufsizes = typing.cast(List[int], bufsizes)
    for spec in collect_specs_from_nodes(
        graph_module.graph.nodes,
        graph_signature,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    ):
        # assume a single memory layer which has mem_id 1
        if spec.mem_id is None:
            spec.mem_id = 1
        # allocate spec.allocated_memory bytes in the buffer
        # with the corresponding mem_id
        spec.realign(alignment)
        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)

    logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
    return bufsizes
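
# Bump-allocation sketch for the nested _allocate_buf helper in naive() above
# (illustrative values only):
#
#     bufsizes = [0, 0]
#     _allocate_buf(bufsizes, mem_id=1, allocated=64)  # returns 0;  bufsizes == [0, 64]
#     _allocate_buf(bufsizes, mem_id=1, allocated=32)  # returns 64; bufsizes == [0, 96]
#     _allocate_buf(bufsizes, mem_id=3, allocated=16)  # returns 0;  bufsizes == [0, 96, 0, 16]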


def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
    for nd in graph_module.graph.nodes:
        if nd.target is torch.ops.higher_order.cond:
            yield nd


def get_while_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
    for nd in graph_module.graph.nodes:
        if nd.target is exir_while:
            yield nd


def get_map_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
    for nd in graph_module.graph.nodes:
        if nd.target is torch.ops.higher_order.map_impl:
            yield nd


def get_return_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]:
    return_specs = set()
    nodes = graph_module.graph.nodes
    if len(nodes) > 0:
        last_node = next(iter(reversed(nodes)))
        for spec in tree_flatten(last_node.meta["spec"])[0]:
            return_specs.add(spec)
    return return_specs


def get_input_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]:
    input_specs = set()
    nodes = graph_module.graph.nodes
    for node in nodes:
        if node.op == "placeholder":
            for spec in tree_flatten(node.meta["spec"])[0]:
                input_specs.add(spec)
    return input_specs


def insert_calls_to_free(
    graph_module: fx.GraphModule, allspecs: Set[TensorSpec]
) -> None:
    """
    Insert calls to free for dynamic unbound tensors whose lifetime has ended.

    Only handles the module itself. Submodules are handled in separate calls to
    this function.

    NOTE: this method will invalidate the lifetime recorded in TensorSpec
    because of the extra free nodes added to the graph.
    """
    # Note: we should never free an output tensor
    return_specs = get_return_specs(graph_module)
    # Note: we should never free an input tensor since the buffer for an input
    # tensor may be passed in by the user.
    input_specs = get_input_specs(graph_module)
    idx_to_dead_specs = defaultdict(list)
    for spec in allspecs:
        if (
            spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
            and spec not in return_specs
            and spec not in input_specs
        ):
            idx_to_dead_specs[spec.lifetime[1]].append(spec)

    num_nodes = len(graph_module.graph.nodes)
    # Iterate in reverse order so inserted nodes do not disturb node numbering.
    for node, node_idx in zip(
        reversed(graph_module.graph.nodes), range(num_nodes - 1, -1, -1)
    ):
        dead_specs = idx_to_dead_specs.get(node_idx, [])
        if not dead_specs:
            continue
        with graph_module.graph.inserting_after(node):
            for spec in dead_specs:
                graph_module.graph.call_function(memory.free, (spec,))
    graph_module.recompile()


def apply_algo(
    algo: Callable[
        [torch.fx.GraphModule, int, Optional[ExportGraphSignature], bool, bool],
        List[int],
    ],
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:
    """
    Recursively apply algo to graph_module and its submodules for control flow.

    Quite naive right now since it does not take the following optimizations
    into consideration:
    1. for a conditional structure, the true branch and the false branch do not
       overlap in lifetime and can share tensor storage
    2. tensors inside a submodule (e.g. the true branch) have opportunities to
       share storage with tensors in the outer module.
    TODO: make these optimizations once we have some baseline working.
    """
    specs = update_all_tensors_lifetime(graph_module, graph_signature)
    bufsizes: List[int] = algo(
        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
    )
    insert_calls_to_free(graph_module, specs)

    def handle_submodule(
        submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
    ) -> None:
        nonlocal bufsizes
        assert submodule_nd.op == "get_attr"
        submodule = getattr(graph_module, submodule_nd.target)
        # Memory planning for a submodule needs to be aware of the amount of
        # buffer already allocated.
        submodule.input_mem_buffer_sizes = bufsizes
        bufsizes = apply_algo(
            algo,
            submodule,
            alignment,
            graph_signature,
            alloc_graph_input=alloc_graph_input,
            alloc_graph_output=True,
        )
        submodule.meta.update({"non_const_buffer_sizes": bufsizes})

    for cond_node in get_cond_nodes(graph_module):
        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))

    for while_node in get_while_nodes(graph_module):
        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
    # TODO: Add test coverage for map operator once dynamo tracing is
    # fully supported for this. T142287208
    for map_node in get_map_nodes(graph_module):
        handle_submodule(
            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
        )

    graph_module.meta.update({"non_const_buffer_sizes": bufsizes})

    return bufsizes
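
# Typical invocation sketch (illustrative; `edge_gm` and `sig` are hypothetical
# placeholders for an edge-dialect torch.fx.GraphModule and its
# ExportGraphSignature):
#
#     bufsizes = apply_algo(
#         greedy,
#         edge_gm,
#         alignment=16,
#         graph_signature=sig,
#         alloc_graph_input=True,
#         alloc_graph_output=True,
#     )
#     # bufsizes[i] is the number of bytes to reserve for memory id i; the list
#     # is also recorded in edge_gm.meta["non_const_buffer_sizes"].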