# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
import logging
import operator
import typing
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from executorch.exir import memory
from executorch.exir.control_flow import while_loop as exir_while
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.error import internal_assert, InternalError
from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.tensor import TensorSpec

from torch import fx
from torch.export.exported_program import ExportGraphSignature
from torch.fx import Node
from torch.utils._pytree import tree_flatten

REGISTERED_ALGOS: Dict[str, Callable[..., List[int]]] = {}


class Verifier:
    """
    Verify whether the outcome of a memory planning algorithm makes sense.
    E.g., make sure tensors that have overlapping lifetimes do not share
    storage/buffers.
    """

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        alloc_graph_input: bool,
        alloc_graph_output: bool,
        graph_signature: Optional[ExportGraphSignature] = None,
    ) -> None:
        self.graph_module = graph_module
        self.graph_signature = graph_signature
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output

    @classmethod
    def mem_obj_id_match(
        cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec, accept_both_none: bool = True
    ) -> bool:
        """
        Given two `TensorSpec`s, return whether their `mem_obj_id`s are the same.
        Note that if both are None, this function will return True if
        `accept_both_none` is True and False otherwise.
        """
        if lhs_spec.mem_id != rhs_spec.mem_id:
            return False

        # both are None
        if lhs_spec.mem_obj_id is None and rhs_spec.mem_obj_id is None:
            return accept_both_none

        return lhs_spec.mem_obj_id == rhs_spec.mem_obj_id

    @classmethod
    def has_overlap(cls, lhs_ivl: List[int], rhs_ivl: List[int]) -> bool:
        r"""
        The passed-in intervals are inclusive on both ends. Return whether they
        overlap.
        """
        # empty interval
        if lhs_ivl[0] > lhs_ivl[1] or rhs_ivl[0] > rhs_ivl[1]:
            return False

        return (lhs_ivl[0] >= rhs_ivl[0] and lhs_ivl[0] <= rhs_ivl[1]) or (
            rhs_ivl[0] >= lhs_ivl[0] and rhs_ivl[0] <= lhs_ivl[1]
        )

    @classmethod
    def lifetime_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
        lhs_lifetime = lhs_spec.lifetime
        rhs_lifetime = rhs_spec.lifetime
        internal_assert(
            lhs_lifetime[0] is not None and lhs_lifetime[1] is not None,
            f"{lhs_spec} should have valid start and end",
        )
        internal_assert(
            rhs_lifetime[0] is not None and rhs_lifetime[1] is not None,
            f"{rhs_spec} should have valid start and end",
        )
        return cls.has_overlap(lhs_lifetime, rhs_lifetime)

    @classmethod
    def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
        intervals = []
        if lhs_spec.mem_id != rhs_spec.mem_id:
            return False
        for spec in [lhs_spec, rhs_spec]:
            internal_assert(
                spec.allocated_memory >= 0,
                f"{spec} should have non-negative allocated memory",
            )
            internal_assert(
                isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
                f"{spec} should have specified memory offset",
            )
            intervals.append(
                [spec.mem_offset, spec.mem_offset + spec.allocated_memory - 1]
            )
        has_overlap = cls.has_overlap(*intervals)

        return has_overlap

    def verify_storage_reuse(
        self, allow_lifetime_and_storage_overlap: bool = False
    ) -> int:
        """
        'allow_lifetime_and_storage_overlap' allows tensors to overlap in both
        lifetime and storage. If it is False and two tensors have both overlapping
        lifetime and storage, throw an exception.
        Returns:
            Number of pairs of tensors that have overlapping storage.
        """
        num_reuse_pairs = 0

        # unique tensor specs
        all_specs = list(
            collect_specs_from_nodes(
                self.graph_module.graph.nodes,
                self.graph_signature,
                ignore_const=True,
                ignore_graph_input=not self.alloc_graph_input,
                ignore_graph_output=not self.alloc_graph_output,
                do_assertion=False,
                ignore_out_var_node=False,
                dedup=True,
            )
        )

        for lhs_spec_idx, lhs_spec in enumerate(all_specs):
            for rhs_spec in all_specs[lhs_spec_idx + 1 :]:
                # Check that both specs are consistent about whether mem_obj_id is defined
                if (lhs_spec.mem_obj_id is None) != (rhs_spec.mem_obj_id is None):
                    raise InternalError(
                        "Specs do not agree on whether mem_obj_id is defined."
                    )

                has_storage_overlap = Verifier.storage_overlap(lhs_spec, rhs_spec)
                if not has_storage_overlap:
                    continue

                if not allow_lifetime_and_storage_overlap and self.lifetime_overlap(
                    lhs_spec, rhs_spec
                ):
                    raise InternalError(
                        f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}"
                    )

                # Check that each mem_obj_id is consistent with whether the tensors have
                # storage overlap
                if not Verifier.mem_obj_id_match(lhs_spec, rhs_spec):
                    raise InternalError(
                        f"Unexpected mem_obj_id mismatch: lhs {lhs_spec}, rhs {rhs_spec}"
                    )

                num_reuse_pairs += 1

        return num_reuse_pairs

    def verify_graph_input_output(self) -> None:
        r"""
        alloc_graph_input / alloc_graph_output indicate whether memory for the
        graph input/output is allocated by the compiler. If not, the runtime will
        set them using buffers provided by users.
        """
        graph_module = self.graph_module
        # There is one tricky case here. If the graph input and graph output
        # tensors overlap, but alloc_graph_input != alloc_graph_output,
        # then the overlapping tensors will cause the assertion below to fail.
        # The current behavior is that if either alloc_graph_input or
        # alloc_graph_output is false, those overlapping tensors will not have
        # memory allocated.
        #
        # Ignore the check in this case for now.
        overlap = get_graph_input_tensors(
            graph_module.graph.nodes, self.graph_signature
        ) & get_graph_output_tensors(graph_module.graph.nodes)
        if overlap and (self.alloc_graph_input != self.alloc_graph_output):
            logging.debug(
                "Having overlapping graph input/output tensors while the allocation decision for graph input/output mismatch."
            )
            return

        graph_input_allocated = None
        graph_output_allocated = None

        has_dynamic_unbound_input = False
        has_dynamic_unbound_output = False

        check_list = {"placeholder", "output"} & {
            node.op for node in graph_module.graph.nodes
        }
        assert "output" in check_list, f"graph module has no output: {graph_module}"

        for nd in graph_module.graph.nodes:
            if nd.op in check_list:
                if not (specs := get_node_tensor_specs(nd)):
                    continue
                if _is_mutable_buffer(nd, self.graph_signature):
                    continue
                assert len(specs) > 0, "Expect tensor specs"
                specs = list(filter(lambda spec: not spec.const, specs))
                if len(specs) == 0:
                    continue
                allocated = any(
                    spec is None or spec.mem_offset is not None for spec in specs
                )
                has_dynamic_unbound_tensor = any(
                    spec is None
                    or spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
                    for spec in specs
                )
                assert (
                    all(spec is None or spec.mem_offset is not None for spec in specs)
                    == allocated
                ), "Either all or none of the tensors should be allocated memory"
                if nd.op == "placeholder":
                    graph_input_allocated = allocated
                    has_dynamic_unbound_input |= has_dynamic_unbound_tensor
                else:
                    graph_output_allocated = allocated
                    has_dynamic_unbound_output |= has_dynamic_unbound_tensor

        if "placeholder" in check_list:
            assert graph_input_allocated is not None, "graph_input_allocated not set"
            if not has_dynamic_unbound_input:
                assert (
                    graph_input_allocated == self.alloc_graph_input
                ), f"Misallocate graph input: {graph_input_allocated} vs. {self.alloc_graph_input}"

        assert graph_output_allocated is not None, "graph_output_allocated not set"
        if not has_dynamic_unbound_output:
            assert (
                graph_output_allocated == self.alloc_graph_output
            ), f"Misallocate graph output {graph_output_allocated} vs. {self.alloc_graph_output}"


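# Illustrative sketch (not part of the original module): a minimal check of the
# inclusive-interval semantics used by Verifier.has_overlap. The interval values
# below are made-up examples, not taken from any real memory plan.
def _example_interval_overlap() -> None:
    # [0, 3] and [3, 5] touch at index 3; inclusive intervals therefore overlap.
    assert Verifier.has_overlap([0, 3], [3, 5])
    # [0, 2] and [3, 5] are disjoint.
    assert not Verifier.has_overlap([0, 2], [3, 5])
    # An empty interval (start > end) never overlaps anything.
    assert not Verifier.has_overlap([5, 1], [0, 10])

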
def _is_out_var_node(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and is_out_variant(node.target._schema.name, node.target._schema.overload_name)
    )


def _is_inplace_node(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and is_inplace_variant(
            node.target._schema.name, node.target._schema.overload_name
        )
    )


def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
    r"""
    Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
    is represented by the indices of the first and last nodes that refer to
    that tensor in their inputs/outputs.

    Arguments:
        spec: the TensorSpec for the tensor
        node_idx: extend the tensor's lifetime to cover node_idx
    """
    start, end = spec.lifetime
    start = node_idx if start is None or start > node_idx else start
    end = node_idx if end is None or end < node_idx else end
    spec.lifetime = [start, end]


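# Illustrative sketch (not part of the original module): how
# update_tensor_lifetime widens a spec's [start, end] interval. The local
# _FakeLifetimeSpec class is a hypothetical stand-in exposing only the
# `lifetime` attribute the helper touches; it is not the real TensorSpec.
def _example_update_tensor_lifetime() -> None:
    class _FakeLifetimeSpec:
        def __init__(self) -> None:
            self.lifetime = [None, None]

    spec = _FakeLifetimeSpec()
    update_tensor_lifetime(spec, 5)  # first sighting at node 5 -> [5, 5]
    update_tensor_lifetime(spec, 2)  # an earlier use extends the start -> [2, 5]
    update_tensor_lifetime(spec, 9)  # a later use extends the end -> [2, 9]
    assert spec.lifetime == [2, 9]

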
# pyre-ignore
def filter_nodes(inputs: Iterable[Any]) -> Iterable[Node]:
    """
    This method needs to return Node objects embedded inside Lists/Dicts as well.
    """
    return [nd for nd in tree_flatten(list(inputs))[0] if isinstance(nd, Node)]


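# Illustrative sketch (not part of the original module): filter_nodes flattens
# nested containers and keeps only fx.Node objects. The placeholder-only graph
# below is a throwaway example built just for this demonstration.
def _example_filter_nodes() -> None:
    g = torch.fx.Graph()
    x = g.placeholder("x")
    y = g.placeholder("y")
    # Nodes may be nested inside lists/dicts and mixed with plain constants.
    mixed_inputs = [x, [y, 3], {"scale": 2.0, "node": x}]
    nodes = filter_nodes(mixed_inputs)
    assert len(nodes) == 3 and all(isinstance(nd, Node) for nd in nodes)

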
def _is_mutable_buffer(
    node: Node, graph_signature: Optional[ExportGraphSignature] = None
) -> bool:
    """
    Check if the node is a mutable buffer according to the provided graph signature.
    """
    # The graph signature is None for memory planning passes not called from
    # EdgeProgramManager; those paths are deprecated, so mutable buffers are not
    # supported on them.
    if graph_signature is None:
        return False
    if node.op == "placeholder":
        if isinstance(node.target, str):
            if node.target in graph_signature.inputs_to_buffers:
                fqn = graph_signature.inputs_to_buffers[node.target]
                # if the buffer is mutated then record that
                if fqn in graph_signature.buffers_to_mutate.values():
                    return True
    return False


def get_graph_input_tensors(
    nodes: Iterable[Node], graph_signature: Optional[ExportGraphSignature] = None
) -> Set[TensorSpec]:
    graph_input_tensors = set()
    for node in nodes:
        if node.op == "placeholder" and not _is_mutable_buffer(node, graph_signature):
            for spec in get_node_tensor_specs(node):
                graph_input_tensors.add(spec)

    return graph_input_tensors


def get_graph_output_tensors(nodes: Iterable[Node]) -> Set[TensorSpec]:
    graph_output_tensors = set()
    for node in nodes:
        if node.op == "output":
            for spec in get_node_tensor_specs(node):
                graph_output_tensors.add(spec)

    return graph_output_tensors


def collect_specs_from_nodes(  # noqa: C901
    nodes: Iterable[Node],
    graph_signature: Optional[ExportGraphSignature] = None,
    ignore_graph_input: bool = False,
    ignore_graph_output: bool = False,
    ignore_const: bool = True,
    ignore_out_var_node: bool = True,
    dedup: bool = True,
    do_assertion: bool = True,
    ignore_dynamic_unbound_tensor: bool = True,
) -> Iterable[TensorSpec]:
    r"""
    Collect specs from the passed-in nodes, filtering as controlled by the
    arguments.
    Arguments:
        ignore_graph_input: ignore graph input tensors from placeholder nodes
        ignore_const: whether to ignore const tensors
        ignore_out_var_node: whether to ignore out-variant nodes
        dedup: whether to dedup
        do_assertion: whether to assert that the filtered nodes belong to a
            restricted set like alloc, getitem
    """
    unique_spec = set()
    graph_input_tensors: Set[TensorSpec] = (
        get_graph_input_tensors(nodes, graph_signature) if ignore_graph_input else set()
    )
    graph_output_tensors: Set[TensorSpec] = (
        get_graph_output_tensors(nodes) if ignore_graph_output else set()
    )

    for node in nodes:
        # ignore the specs from irrelevant Fx ops for now.
        if node.op in ["get_attr"]:
            continue

        # don't reallocate memory for an out-variant op's output tensors,
        # since they are just input tensors.
        if ignore_out_var_node and _is_out_var_node(node):
            continue

        if not (specs := get_node_tensor_specs(node)):
            continue

        if _is_inplace_node(node):
            continue

        if do_assertion:
            internal_assert(
                node.op in ("placeholder", "output")
                or node.target
                in [
                    memory.alloc,
                    memory.view,
                    operator.getitem,
                    torch.ops.higher_order.cond,
                    exir_while,
                    torch.ops.higher_order.map_impl,
                    executorch_call_delegate,
                ],
                f"Unexpected op {node.op}, target {node.target}",
            )
        for spec in specs:
            if spec is None:
                continue
            # Dynamic unbound tensors' memory will be allocated by the runtime.
            # Memory planning should ignore them.
            if (
                ignore_dynamic_unbound_tensor
                and spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
            ):
                continue

            # Note: a graph input may be the output of other ops (e.g. the return op).
            # If ignore_graph_input is true, we should ignore those tensors so
            # that we skip planning memory for graph inputs.
            if ignore_graph_input and spec in graph_input_tensors:
                continue
            if ignore_graph_output and spec in graph_output_tensors:
                continue
            if (
                ignore_const
                and spec.const
                and not node.meta.get("weight_has_gradient", False)
            ):
                continue
            if dedup:
                if spec in unique_spec:
                    continue
                else:
                    unique_spec.add(spec)
            yield spec


def update_all_tensors_lifetime(
    graph_module: torch.fx.GraphModule,
    graph_signature: Optional[ExportGraphSignature] = None,
) -> Set[TensorSpec]:
    r"""
    Set the lifetime for all the tensors encountered in the Fx graph.
    """
    specs = set()
    for node_idx, node in enumerate(graph_module.graph.nodes):
        for spec in collect_specs_from_nodes(
            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
            graph_signature,
            ignore_graph_input=False,
            ignore_const=False,
            ignore_out_var_node=False,
            dedup=False,
            do_assertion=False,
            ignore_dynamic_unbound_tensor=False,
        ):
            update_tensor_lifetime(spec, node_idx)
            specs.add(spec)
    return specs


@dataclass
class SharedObject:
    r"""
    We define the concept of a shared object, which represents a segment
    in the memory buffer that can be shared by multiple tensors. To check
    whether a shared object is available for a tensor, we maintain the
    last_used_index attribute. The shared object is available for nodes
    with an index greater than last_used_index.
    """

    # index of the shared object in the list of shared objects, used as a unique id
    idx: int
    # offset in the memory buffer
    offset: int
    # size of this shared object in bytes
    size: int
    # the object will be available for index (last_used_index + 1)
    last_used_index: int


def materialize_buffer(
    shared_objects: List[SharedObject], input_total_size: int = 0
) -> int:
    r"""
    Assign a concrete location (offset) in the buffer to each SharedObject.

    Assumes all the passed-in shared objects belong to the same memory buffer.
    """
    total_size = input_total_size
    for sobj in shared_objects:
        sobj.offset = total_size
        total_size += sobj.size
    return total_size


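# Illustrative sketch (not part of the original module): materialize_buffer lays
# shared objects out back to back, so each offset is the running sum of the
# preceding sizes. The sizes and the 16-byte starting offset are arbitrary
# example values.
def _example_materialize_buffer() -> None:
    objs = [
        SharedObject(idx=0, offset=-1, size=64, last_used_index=3),
        SharedObject(idx=1, offset=-1, size=32, last_used_index=7),
    ]
    total = materialize_buffer(objs, input_total_size=16)
    assert [o.offset for o in objs] == [16, 80]
    assert total == 112

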
def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int:
    r"""
    Calculate the absolute difference between the size of a shared object and
    a tensor.
    """
    return abs(sobj.size - spec.allocated_memory)


def pick_shared_obj(
    shared_objects: List[SharedObject], spec: TensorSpec
) -> SharedObject:
    r"""
    Pick the available shared object whose size is closest to the tensor's.
    If no available shared object is left, create a new one.
    """
    # TODO: do better than linear scan
    picked = None
    for sobj in shared_objects:
        if spec.lifetime[0] > sobj.last_used_index:
            if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif(
                picked, spec
            ):
                picked = sobj
                sobj.last_used_index = spec.lifetime[1]
                sobj.size = max(sobj.size, spec.allocated_memory)
    if picked is None:
        picked = SharedObject(
            len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
        )
        shared_objects.append(picked)

    return picked


def get_node_tensor_specs(
    node: torch.fx.Node,
) -> Union[List[TensorSpec], Tuple[TensorSpec]]:
    r"""
    Return the list of tensor specs for the node, or an empty list if the node
    has no tensor specs.
    """
    # get tensor specs
    if node.target == memory.view:
        base = node.args[0]
        assert isinstance(base, torch.fx.Node)
        specs = base.meta.get("spec")
    else:
        specs = node.meta.get("spec")

    if isinstance(specs, TensorSpec):
        specs = [specs]
    if not isinstance(specs, (list, tuple)):
        return []
    else:
        return [
            spec
            for spec in specs
            if not isinstance(spec, (int, float, bool, str, type(None)))
        ]


def greedy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:
    spec2obj = {}
    shared_objects = defaultdict(list)
    # Don't do assertion in collect_specs_from_nodes if we have already encountered
    # and ignored some to_out_variant errors.
    do_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False)
    # For each tensor, pick the available shared object with the closest size to
    # the tensor. If no available shared object is left, create a new one.
    for spec in collect_specs_from_nodes(
        graph_module.graph.nodes,
        graph_signature,
        do_assertion=do_assertion,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    ):
        if spec.mem_id is None:
            spec.mem_id = 1
        spec.realign(alignment)
        spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec)

    if len(shared_objects) == 0:
        # Cannot find any tensor in the graph that needs to be allocated.
        # Return [0, 0] to be consistent with the default behavior of naive.
        total_sizes = [0, 0]
    else:
        total_sizes = [0] * (max(shared_objects.keys()) + 1)
        for mem_id in shared_objects:
            input_total_size = 0
            if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
                # pyre-fixme[6]: For 1st argument expected
                #  `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
                if len(bufsizes) > mem_id:
                    # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
                    input_total_size = bufsizes[mem_id]
            total_sizes[mem_id] = materialize_buffer(
                shared_objects[mem_id], input_total_size
            )

        # Since we now know the number of shared objects we need and the size of
        # each shared object, we can assign an offset in the memory buffer for
        # each shared object.
        for spec, sobj in spec2obj.items():
            spec.mem_obj_id = sobj.idx
            spec.mem_offset = sobj.offset

    logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
    return total_sizes


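# Illustrative sketch (not part of the original module): the core of the greedy
# strategy in miniature, driving pick_shared_obj and materialize_buffer by hand.
# The local _FakePlanSpec class is a hypothetical stand-in exposing only the
# attributes these helpers read (lifetime, allocated_memory); it is not the real
# TensorSpec, and the lifetimes/sizes are made-up example values.
def _example_greedy_in_miniature() -> None:
    class _FakePlanSpec:
        def __init__(self, start: int, end: int, nbytes: int) -> None:
            self.lifetime = [start, end]
            self.allocated_memory = nbytes

    specs = [
        _FakePlanSpec(0, 2, 64),  # lives for nodes 0..2
        _FakePlanSpec(3, 5, 48),  # starts after the first dies -> can reuse
        _FakePlanSpec(4, 6, 128),  # overlaps the second -> needs its own object
    ]
    shared_objects: List[SharedObject] = []
    picked = [pick_shared_obj(shared_objects, spec) for spec in specs]
    # The second spec reuses the first shared object; the third gets a new one.
    assert picked[0] is picked[1] and picked[2] is not picked[0]
    total = materialize_buffer(shared_objects)
    assert total == 64 + 128  # the reused object keeps the larger of 64 and 48

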
def naive(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:

    # allocate 'allocated' bytes from buffer with id mem_id.
    # return the starting offset of the allocated buffer.
    def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
        if mem_id >= len(bufsizes):
            bufsizes.extend([0] * (mem_id - len(bufsizes) + 1))
        ret = bufsizes[mem_id]
        bufsizes[mem_id] += allocated
        return ret

    bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
    if bufsizes is None:
        bufsizes = [0, 0]

    bufsizes = typing.cast(List[int], bufsizes)
    for spec in collect_specs_from_nodes(
        graph_module.graph.nodes,
        graph_signature,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    ):
        # assume a single memory layer which has mem_id 1
        if spec.mem_id is None:
            spec.mem_id = 1
        # allocate spec.allocated_memory bytes in the buffer
        # with the corresponding mem_id
        spec.realign(alignment)
        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)

    logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
    return bufsizes


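# Illustrative sketch (not part of the original module): the bump-allocation
# bookkeeping that naive() performs per mem_id, re-implemented here as a local
# helper (_bump) with made-up sizes, since the real _allocate_buf is nested
# inside naive() and not reachable from outside.
def _example_naive_bump_allocation() -> None:
    def _bump(bufsizes: List[int], mem_id: int, nbytes: int) -> int:
        if mem_id >= len(bufsizes):
            bufsizes.extend([0] * (mem_id - len(bufsizes) + 1))
        offset = bufsizes[mem_id]
        bufsizes[mem_id] += nbytes
        return offset

    bufsizes = [0, 0]
    assert _bump(bufsizes, 1, 64) == 0  # first tensor starts at offset 0
    assert _bump(bufsizes, 1, 32) == 64  # the next one is appended after it
    assert bufsizes == [0, 96]  # mem_id 1 now needs 96 bytes in total

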
def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
    for nd in graph_module.graph.nodes:
        if nd.target is torch.ops.higher_order.cond:
            yield nd


def get_while_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
    for nd in graph_module.graph.nodes:
        if nd.target is exir_while:
            yield nd


def get_map_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
    for nd in graph_module.graph.nodes:
        if nd.target is torch.ops.higher_order.map_impl:
            yield nd


def get_return_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]:
    return_specs = set()
    nodes = graph_module.graph.nodes
    if len(nodes) > 0:
        last_node = next(iter(reversed(nodes)))
        for spec in tree_flatten(last_node.meta["spec"])[0]:
            return_specs.add(spec)
    return return_specs


def get_input_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]:
    input_specs = set()
    nodes = graph_module.graph.nodes
    for node in nodes:
        if node.op == "placeholder":
            for spec in tree_flatten(node.meta["spec"])[0]:
                input_specs.add(spec)
    return input_specs


def insert_calls_to_free(
    graph_module: fx.GraphModule, allspecs: Set[TensorSpec]
) -> None:
    """
    Insert calls to free for dynamic unbound tensors that go out of lifetime.

    Only handles the module itself. Submodules are handled in separate calls
    to this function.

    NOTE: this method will invalidate the lifetimes recorded in TensorSpec
    because of the extra free nodes added to the graph.
    """
    # Note: we should never free an output tensor
    return_specs = get_return_specs(graph_module)
    # Note: we should never free an input tensor since the buffer for an input
    # tensor may be passed in by the user.
    input_specs = get_input_specs(graph_module)
    idx_to_dead_specs = defaultdict(list)
    for spec in allspecs:
        if (
            spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
            and spec not in return_specs
            and spec not in input_specs
        ):
            idx_to_dead_specs[spec.lifetime[1]].append(spec)

    num_nodes = len(graph_module.graph.nodes)
    # iterate in reverse order so inserted nodes do not disturb node
    # numbering.
    for node, node_idx in zip(
        reversed(graph_module.graph.nodes), range(num_nodes - 1, -1, -1)
    ):
        dead_specs = idx_to_dead_specs.get(node_idx, [])
        if not dead_specs:
            continue
        with graph_module.graph.inserting_after(node):
            for spec in dead_specs:
                graph_module.graph.call_function(memory.free, (spec,))
    graph_module.recompile()


def apply_algo(
    algo: Callable[
        [torch.fx.GraphModule, int, Optional[ExportGraphSignature], bool, bool],
        List[int],
    ],
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:
    """
    Recursively apply algo to graph_module and its submodules for control flow.

    Quite naive right now since it does not take the following optimizations
    into consideration:
    1. for a conditional structure, the true branch and the false branch do not
       overlap in lifetime and can share tensor storage
    2. tensors inside a submodule (e.g. the true branch) have opportunities to
       share storage with tensors in the outer module.
    TODO: make these optimizations once we have some baseline working.
    """
    specs = update_all_tensors_lifetime(graph_module, graph_signature)
    bufsizes: List[int] = algo(
        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
    )
    insert_calls_to_free(graph_module, specs)

    def handle_submodule(
        submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
    ) -> None:
        nonlocal bufsizes
        assert submodule_nd.op == "get_attr"
        submodule = getattr(graph_module, submodule_nd.target)
        # memory planning for the submodule needs to be aware of the amount of
        # buffer already allocated.
        submodule.input_mem_buffer_sizes = bufsizes
        bufsizes = apply_algo(
            algo,
            submodule,
            alignment,
            graph_signature,
            alloc_graph_input=alloc_graph_input,
            alloc_graph_output=True,
        )
        submodule.meta.update({"non_const_buffer_sizes": bufsizes})

    for cond_node in get_cond_nodes(graph_module):
        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))

    for while_node in get_while_nodes(graph_module):
        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
    # TODO: Add test coverage for map operator once dynamo tracing is
    # fully supported for this. T142287208
    for map_node in get_map_nodes(graph_module):
        handle_submodule(
            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
        )

    graph_module.meta.update({"non_const_buffer_sizes": bufsizes})

    return bufsizes