xref: /aosp_15_r20/external/executorch/exir/memory_planning.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport itertools
10*523fa7a6SAndroid Build Coastguard Workerimport logging
11*523fa7a6SAndroid Build Coastguard Workerimport operator
12*523fa7a6SAndroid Build Coastguard Workerimport typing
13*523fa7a6SAndroid Build Coastguard Workerfrom collections import defaultdict
14*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import dataclass
15*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Workerimport torch
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import memory
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.control_flow import while_loop as exir_while
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.delegate import executorch_call_delegate
21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import internal_assert, InternalError
22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.operator.convert import is_inplace_variant, is_out_variant
23*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import TensorShapeDynamism
24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tensor import TensorSpec
25*523fa7a6SAndroid Build Coastguard Worker
26*523fa7a6SAndroid Build Coastguard Workerfrom torch import fx
27*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ExportGraphSignature
28*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import Node
29*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_flatten
30*523fa7a6SAndroid Build Coastguard Worker
# Registry mapping an algorithm name to its memory-planning callable (which
# returns the list of per-memory-pool buffer sizes).
# NOTE(review): presumably populated by a registration helper defined later in
# this file — confirm against the full module.
REGISTERED_ALGOS: Dict[str, Callable[..., List[int]]] = {}
32*523fa7a6SAndroid Build Coastguard Worker
33*523fa7a6SAndroid Build Coastguard Worker
34*523fa7a6SAndroid Build Coastguard Workerclass Verifier:
35*523fa7a6SAndroid Build Coastguard Worker    """
36*523fa7a6SAndroid Build Coastguard Worker    Verify if the outcome of a memory planning algorithm makes sense.
37*523fa7a6SAndroid Build Coastguard Worker    E.g., make sure tensors having overlapping lifetime does not have overlapping
38*523fa7a6SAndroid Build Coastguard Worker    storage/buffer.
39*523fa7a6SAndroid Build Coastguard Worker    """
40*523fa7a6SAndroid Build Coastguard Worker
41*523fa7a6SAndroid Build Coastguard Worker    def __init__(
42*523fa7a6SAndroid Build Coastguard Worker        self,
43*523fa7a6SAndroid Build Coastguard Worker        graph_module: torch.fx.GraphModule,
44*523fa7a6SAndroid Build Coastguard Worker        alloc_graph_input: bool,
45*523fa7a6SAndroid Build Coastguard Worker        alloc_graph_output: bool,
46*523fa7a6SAndroid Build Coastguard Worker        graph_signature: Optional[ExportGraphSignature] = None,
47*523fa7a6SAndroid Build Coastguard Worker    ) -> None:
48*523fa7a6SAndroid Build Coastguard Worker        self.graph_module = graph_module
49*523fa7a6SAndroid Build Coastguard Worker        self.graph_signature = graph_signature
50*523fa7a6SAndroid Build Coastguard Worker        self.alloc_graph_input = alloc_graph_input
51*523fa7a6SAndroid Build Coastguard Worker        self.alloc_graph_output = alloc_graph_output
52*523fa7a6SAndroid Build Coastguard Worker
53*523fa7a6SAndroid Build Coastguard Worker    @classmethod
54*523fa7a6SAndroid Build Coastguard Worker    def mem_obj_id_match(
55*523fa7a6SAndroid Build Coastguard Worker        cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec, accept_both_none: bool = True
56*523fa7a6SAndroid Build Coastguard Worker    ) -> bool:
57*523fa7a6SAndroid Build Coastguard Worker        """
58*523fa7a6SAndroid Build Coastguard Worker        Given two `TensorSpec`, return if their `mem_obj_id` are the same. Note that if
59*523fa7a6SAndroid Build Coastguard Worker        both are None, this function will return True if `accept_both_none` is True and
60*523fa7a6SAndroid Build Coastguard Worker        False otherwise.
61*523fa7a6SAndroid Build Coastguard Worker        """
62*523fa7a6SAndroid Build Coastguard Worker        if lhs_spec.mem_id != rhs_spec.mem_id:
63*523fa7a6SAndroid Build Coastguard Worker            return False
64*523fa7a6SAndroid Build Coastguard Worker
65*523fa7a6SAndroid Build Coastguard Worker        # both are None
66*523fa7a6SAndroid Build Coastguard Worker        if lhs_spec.mem_obj_id is None and rhs_spec.mem_obj_id is None:
67*523fa7a6SAndroid Build Coastguard Worker            return accept_both_none
68*523fa7a6SAndroid Build Coastguard Worker
69*523fa7a6SAndroid Build Coastguard Worker        return lhs_spec.mem_obj_id == rhs_spec.mem_obj_id
70*523fa7a6SAndroid Build Coastguard Worker
71*523fa7a6SAndroid Build Coastguard Worker    @classmethod
72*523fa7a6SAndroid Build Coastguard Worker    def has_overlap(cls, lhs_ivl: List[int], rhs_ivl: List[int]) -> bool:
73*523fa7a6SAndroid Build Coastguard Worker        r"""
74*523fa7a6SAndroid Build Coastguard Worker        The passed in intervals are inclusive in both sides. Return if they have
75*523fa7a6SAndroid Build Coastguard Worker        overlapping.
76*523fa7a6SAndroid Build Coastguard Worker        """
77*523fa7a6SAndroid Build Coastguard Worker        # empty interval
78*523fa7a6SAndroid Build Coastguard Worker        if lhs_ivl[0] > lhs_ivl[1] or rhs_ivl[0] > rhs_ivl[1]:
79*523fa7a6SAndroid Build Coastguard Worker            return False
80*523fa7a6SAndroid Build Coastguard Worker
81*523fa7a6SAndroid Build Coastguard Worker        return (lhs_ivl[0] >= rhs_ivl[0] and lhs_ivl[0] <= rhs_ivl[1]) or (
82*523fa7a6SAndroid Build Coastguard Worker            rhs_ivl[0] >= lhs_ivl[0] and rhs_ivl[0] <= lhs_ivl[1]
83*523fa7a6SAndroid Build Coastguard Worker        )
84*523fa7a6SAndroid Build Coastguard Worker
85*523fa7a6SAndroid Build Coastguard Worker    @classmethod
86*523fa7a6SAndroid Build Coastguard Worker    def lifetime_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
87*523fa7a6SAndroid Build Coastguard Worker        lhs_lifetime = lhs_spec.lifetime
88*523fa7a6SAndroid Build Coastguard Worker        rhs_lifetime = rhs_spec.lifetime
89*523fa7a6SAndroid Build Coastguard Worker        internal_assert(
90*523fa7a6SAndroid Build Coastguard Worker            lhs_lifetime[0] is not None and lhs_lifetime[1] is not None,
91*523fa7a6SAndroid Build Coastguard Worker            f"{lhs_spec} should have valid start and end",
92*523fa7a6SAndroid Build Coastguard Worker        )
93*523fa7a6SAndroid Build Coastguard Worker        internal_assert(
94*523fa7a6SAndroid Build Coastguard Worker            rhs_lifetime[0] is not None and rhs_lifetime[1] is not None,
95*523fa7a6SAndroid Build Coastguard Worker            f"{rhs_spec} should have valid start and end",
96*523fa7a6SAndroid Build Coastguard Worker        )
97*523fa7a6SAndroid Build Coastguard Worker        return cls.has_overlap(lhs_lifetime, rhs_lifetime)
98*523fa7a6SAndroid Build Coastguard Worker
99*523fa7a6SAndroid Build Coastguard Worker    @classmethod
100*523fa7a6SAndroid Build Coastguard Worker    def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
101*523fa7a6SAndroid Build Coastguard Worker        intervals = []
102*523fa7a6SAndroid Build Coastguard Worker        if lhs_spec.mem_id != rhs_spec.mem_id:
103*523fa7a6SAndroid Build Coastguard Worker            return False
104*523fa7a6SAndroid Build Coastguard Worker        for spec in [lhs_spec, rhs_spec]:
105*523fa7a6SAndroid Build Coastguard Worker            internal_assert(
106*523fa7a6SAndroid Build Coastguard Worker                spec.allocated_memory >= 0,
107*523fa7a6SAndroid Build Coastguard Worker                f"{spec} should have non-zero allocated memory",
108*523fa7a6SAndroid Build Coastguard Worker            )
109*523fa7a6SAndroid Build Coastguard Worker            internal_assert(
110*523fa7a6SAndroid Build Coastguard Worker                isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
111*523fa7a6SAndroid Build Coastguard Worker                f"{spec} should have specified memory offset",
112*523fa7a6SAndroid Build Coastguard Worker            )
113*523fa7a6SAndroid Build Coastguard Worker            intervals.append(
114*523fa7a6SAndroid Build Coastguard Worker                [spec.mem_offset, spec.mem_offset + spec.allocated_memory - 1]
115*523fa7a6SAndroid Build Coastguard Worker            )
116*523fa7a6SAndroid Build Coastguard Worker        has_overlap = cls.has_overlap(*intervals)
117*523fa7a6SAndroid Build Coastguard Worker
118*523fa7a6SAndroid Build Coastguard Worker        return has_overlap
119*523fa7a6SAndroid Build Coastguard Worker
120*523fa7a6SAndroid Build Coastguard Worker    def verify_storage_reuse(
121*523fa7a6SAndroid Build Coastguard Worker        self, allow_lifetime_and_storage_overlap: bool = False
122*523fa7a6SAndroid Build Coastguard Worker    ) -> int:
123*523fa7a6SAndroid Build Coastguard Worker        """
124*523fa7a6SAndroid Build Coastguard Worker        'allow_lifetime_and_storage_overlap' allows tensors to overlap in both
125*523fa7a6SAndroid Build Coastguard Worker        lifetime and storage. If is it False, and two tensors have both overlapping
126*523fa7a6SAndroid Build Coastguard Worker        lifetime and storage, throw an exception.
127*523fa7a6SAndroid Build Coastguard Worker        Returns:
128*523fa7a6SAndroid Build Coastguard Worker            Number of pairs of tenors that have overlapping storage.
129*523fa7a6SAndroid Build Coastguard Worker        """
130*523fa7a6SAndroid Build Coastguard Worker        num_reuse_pairs = 0
131*523fa7a6SAndroid Build Coastguard Worker
132*523fa7a6SAndroid Build Coastguard Worker        # unique tensors specs
133*523fa7a6SAndroid Build Coastguard Worker        all_specs = list(
134*523fa7a6SAndroid Build Coastguard Worker            collect_specs_from_nodes(
135*523fa7a6SAndroid Build Coastguard Worker                self.graph_module.graph.nodes,
136*523fa7a6SAndroid Build Coastguard Worker                self.graph_signature,
137*523fa7a6SAndroid Build Coastguard Worker                ignore_const=True,
138*523fa7a6SAndroid Build Coastguard Worker                ignore_graph_input=not self.alloc_graph_input,
139*523fa7a6SAndroid Build Coastguard Worker                ignore_graph_output=not self.alloc_graph_output,
140*523fa7a6SAndroid Build Coastguard Worker                do_assertion=False,
141*523fa7a6SAndroid Build Coastguard Worker                ignore_out_var_node=False,
142*523fa7a6SAndroid Build Coastguard Worker                dedup=True,
143*523fa7a6SAndroid Build Coastguard Worker            )
144*523fa7a6SAndroid Build Coastguard Worker        )
145*523fa7a6SAndroid Build Coastguard Worker
146*523fa7a6SAndroid Build Coastguard Worker        for lhs_spec_idx, lhs_spec in enumerate(all_specs):
147*523fa7a6SAndroid Build Coastguard Worker            for rhs_spec in all_specs[lhs_spec_idx + 1 :]:
148*523fa7a6SAndroid Build Coastguard Worker                # Check that both specs are consistent about whether mem_obj_id is defined
149*523fa7a6SAndroid Build Coastguard Worker                if (lhs_spec.mem_obj_id is None) != (rhs_spec.mem_obj_id is None):
150*523fa7a6SAndroid Build Coastguard Worker                    raise InternalError(
151*523fa7a6SAndroid Build Coastguard Worker                        "Specs do not agree on whether mem_obj_id is defined."
152*523fa7a6SAndroid Build Coastguard Worker                    )
153*523fa7a6SAndroid Build Coastguard Worker
154*523fa7a6SAndroid Build Coastguard Worker                has_storage_overlap = Verifier.storage_overlap(lhs_spec, rhs_spec)
155*523fa7a6SAndroid Build Coastguard Worker                if not has_storage_overlap:
156*523fa7a6SAndroid Build Coastguard Worker                    continue
157*523fa7a6SAndroid Build Coastguard Worker
158*523fa7a6SAndroid Build Coastguard Worker                if not allow_lifetime_and_storage_overlap and self.lifetime_overlap(
159*523fa7a6SAndroid Build Coastguard Worker                    lhs_spec, rhs_spec
160*523fa7a6SAndroid Build Coastguard Worker                ):
161*523fa7a6SAndroid Build Coastguard Worker                    raise InternalError(
162*523fa7a6SAndroid Build Coastguard Worker                        f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}"
163*523fa7a6SAndroid Build Coastguard Worker                    )
164*523fa7a6SAndroid Build Coastguard Worker
165*523fa7a6SAndroid Build Coastguard Worker                # Check that each mem_obj_id is consistent with whether the tensors have
166*523fa7a6SAndroid Build Coastguard Worker                # storage overlap
167*523fa7a6SAndroid Build Coastguard Worker                if not Verifier.mem_obj_id_match(lhs_spec, rhs_spec):
168*523fa7a6SAndroid Build Coastguard Worker                    raise InternalError(
169*523fa7a6SAndroid Build Coastguard Worker                        f"Unexpected mem_obj_id mismatch: lhs {lhs_spec}, rhs {rhs_spec}"
170*523fa7a6SAndroid Build Coastguard Worker                    )
171*523fa7a6SAndroid Build Coastguard Worker
172*523fa7a6SAndroid Build Coastguard Worker                num_reuse_pairs += 1
173*523fa7a6SAndroid Build Coastguard Worker
174*523fa7a6SAndroid Build Coastguard Worker        return num_reuse_pairs
175*523fa7a6SAndroid Build Coastguard Worker
176*523fa7a6SAndroid Build Coastguard Worker    def verify_graph_input_output(self) -> None:
177*523fa7a6SAndroid Build Coastguard Worker        r"""
178*523fa7a6SAndroid Build Coastguard Worker        alloc_graph_input / alloc_graph_output indicas if memory for graph
179*523fa7a6SAndroid Build Coastguard Worker        input/output is allocated by the compiler. If not, the runtime will
180*523fa7a6SAndroid Build Coastguard Worker        set them using buffers provided by users.
181*523fa7a6SAndroid Build Coastguard Worker        """
182*523fa7a6SAndroid Build Coastguard Worker        graph_module = self.graph_module
183*523fa7a6SAndroid Build Coastguard Worker        # There is one tricky case here. If the graph input and graph output
184*523fa7a6SAndroid Build Coastguard Worker        # tensors have overlap, but alloc_graph_input != alloc_graph_output,
185*523fa7a6SAndroid Build Coastguard Worker        # then the overlapped tensor will cause assertion failure below.
186*523fa7a6SAndroid Build Coastguard Worker        # The current behavior is if either alloc_graph_input or alloc_graph_output
187*523fa7a6SAndroid Build Coastguard Worker        # is false, those overlapped tensor will not have memory allocated.
188*523fa7a6SAndroid Build Coastguard Worker        #
189*523fa7a6SAndroid Build Coastguard Worker        # Ignore the check in this case for now.
190*523fa7a6SAndroid Build Coastguard Worker        overlap = get_graph_input_tensors(
191*523fa7a6SAndroid Build Coastguard Worker            graph_module.graph.nodes, self.graph_signature
192*523fa7a6SAndroid Build Coastguard Worker        ) & get_graph_output_tensors(graph_module.graph.nodes)
193*523fa7a6SAndroid Build Coastguard Worker        if overlap and (self.alloc_graph_input != self.alloc_graph_output):
194*523fa7a6SAndroid Build Coastguard Worker            logging.debug(
195*523fa7a6SAndroid Build Coastguard Worker                "Having overlapping graph input/output tensors while the allocation decision for graph input/output mismatch."
196*523fa7a6SAndroid Build Coastguard Worker            )
197*523fa7a6SAndroid Build Coastguard Worker            return
198*523fa7a6SAndroid Build Coastguard Worker
199*523fa7a6SAndroid Build Coastguard Worker        graph_input_allocated = None
200*523fa7a6SAndroid Build Coastguard Worker        graph_output_allocated = None
201*523fa7a6SAndroid Build Coastguard Worker
202*523fa7a6SAndroid Build Coastguard Worker        has_dynamic_unbound_input = False
203*523fa7a6SAndroid Build Coastguard Worker        has_dynamic_unbound_output = False
204*523fa7a6SAndroid Build Coastguard Worker
205*523fa7a6SAndroid Build Coastguard Worker        check_list = {"placeholder", "output"} & {
206*523fa7a6SAndroid Build Coastguard Worker            node.op for node in graph_module.graph.nodes
207*523fa7a6SAndroid Build Coastguard Worker        }
208*523fa7a6SAndroid Build Coastguard Worker        assert "output" in check_list, f"graph module has no output: {graph_module}"
209*523fa7a6SAndroid Build Coastguard Worker
210*523fa7a6SAndroid Build Coastguard Worker        for nd in graph_module.graph.nodes:
211*523fa7a6SAndroid Build Coastguard Worker            if nd.op in check_list:
212*523fa7a6SAndroid Build Coastguard Worker                if not (specs := get_node_tensor_specs(nd)):
213*523fa7a6SAndroid Build Coastguard Worker                    continue
214*523fa7a6SAndroid Build Coastguard Worker                if _is_mutable_buffer(nd, self.graph_signature):
215*523fa7a6SAndroid Build Coastguard Worker                    continue
216*523fa7a6SAndroid Build Coastguard Worker                assert len(specs) > 0, "Expect tensor specs"
217*523fa7a6SAndroid Build Coastguard Worker                specs = list(filter(lambda spec: not spec.const, specs))
218*523fa7a6SAndroid Build Coastguard Worker                if len(specs) == 0:
219*523fa7a6SAndroid Build Coastguard Worker                    continue
220*523fa7a6SAndroid Build Coastguard Worker                allocated = any(
221*523fa7a6SAndroid Build Coastguard Worker                    spec is None or spec.mem_offset is not None for spec in specs
222*523fa7a6SAndroid Build Coastguard Worker                )
223*523fa7a6SAndroid Build Coastguard Worker                has_dynamic_unbound_tensor = any(
224*523fa7a6SAndroid Build Coastguard Worker                    spec is None
225*523fa7a6SAndroid Build Coastguard Worker                    or spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
226*523fa7a6SAndroid Build Coastguard Worker                    for spec in specs
227*523fa7a6SAndroid Build Coastguard Worker                )
228*523fa7a6SAndroid Build Coastguard Worker                assert (
229*523fa7a6SAndroid Build Coastguard Worker                    all(spec is None or spec.mem_offset is not None for spec in specs)
230*523fa7a6SAndroid Build Coastguard Worker                    == allocated
231*523fa7a6SAndroid Build Coastguard Worker                ), "Either all or non of the tensors should be allocated memory"
232*523fa7a6SAndroid Build Coastguard Worker                if nd.op == "placeholder":
233*523fa7a6SAndroid Build Coastguard Worker                    graph_input_allocated = allocated
234*523fa7a6SAndroid Build Coastguard Worker                    has_dynamic_unbound_input |= has_dynamic_unbound_tensor
235*523fa7a6SAndroid Build Coastguard Worker                else:
236*523fa7a6SAndroid Build Coastguard Worker                    graph_output_allocated = allocated
237*523fa7a6SAndroid Build Coastguard Worker                    has_dynamic_unbound_output |= has_dynamic_unbound_tensor
238*523fa7a6SAndroid Build Coastguard Worker
239*523fa7a6SAndroid Build Coastguard Worker        if "placeholder" in check_list:
240*523fa7a6SAndroid Build Coastguard Worker            assert graph_input_allocated is not None, "graph_input_allocated not set"
241*523fa7a6SAndroid Build Coastguard Worker            if not has_dynamic_unbound_input:
242*523fa7a6SAndroid Build Coastguard Worker                assert (
243*523fa7a6SAndroid Build Coastguard Worker                    graph_input_allocated == self.alloc_graph_input
244*523fa7a6SAndroid Build Coastguard Worker                ), f"Misallocate graph input: {graph_input_allocated} v.s. {self.alloc_graph_input}"
245*523fa7a6SAndroid Build Coastguard Worker
246*523fa7a6SAndroid Build Coastguard Worker        assert graph_output_allocated is not None, "graph_output_allocated not set"
247*523fa7a6SAndroid Build Coastguard Worker        if not has_dynamic_unbound_output:
248*523fa7a6SAndroid Build Coastguard Worker            assert (
249*523fa7a6SAndroid Build Coastguard Worker                graph_output_allocated == self.alloc_graph_output
250*523fa7a6SAndroid Build Coastguard Worker            ), f"Misallocate graph output {graph_output_allocated} v.s. {self.alloc_graph_output}"
251*523fa7a6SAndroid Build Coastguard Worker
252*523fa7a6SAndroid Build Coastguard Worker
def _is_out_var_node(node: torch.fx.Node) -> bool:
    """Return True iff ``node`` is a call to an out-variant operator overload."""
    if node.op != "call_function":
        return False
    target = node.target
    if not isinstance(target, torch._ops.OpOverload):
        return False
    schema = target._schema
    return is_out_variant(schema.name, schema.overload_name)
259*523fa7a6SAndroid Build Coastguard Worker
260*523fa7a6SAndroid Build Coastguard Worker
def _is_inplace_node(node: torch.fx.Node) -> bool:
    """Return True iff ``node`` is a call to an in-place operator overload."""
    if node.op != "call_function":
        return False
    target = node.target
    if not isinstance(target, torch._ops.OpOverload):
        return False
    schema = target._schema
    return is_inplace_variant(schema.name, schema.overload_name)
269*523fa7a6SAndroid Build Coastguard Worker
270*523fa7a6SAndroid Build Coastguard Worker
def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
    r"""
    Extend the tensor's lifetime interval so that it covers node_idx.

    A tensor's lifetime is represented by the indices of the first and last
    nodes referring to that tensor in their inputs/outputs, stored as a
    two-element list ``[first, last]`` on the spec.

    Arguments:
        spec: the TensorSpec for the tensor (mutated in place)
        node_idx: extend the tensor's lifetime to cover node_idx
    """
    first, last = spec.lifetime
    if first is None or node_idx < first:
        first = node_idx
    if last is None or node_idx > last:
        last = node_idx
    spec.lifetime = [first, last]
285*523fa7a6SAndroid Build Coastguard Worker
286*523fa7a6SAndroid Build Coastguard Worker
# pyre-ignore
def filter_nodes(inputs: Iterable[Any]) -> Iterable[Node]:
    """
    Return every ``Node`` found in ``inputs``, including Nodes embedded inside
    nested List/Dict containers, preserving flattening order.
    """
    flattened, _ = tree_flatten(list(inputs))
    return [item for item in flattened if isinstance(item, Node)]
293*523fa7a6SAndroid Build Coastguard Worker
294*523fa7a6SAndroid Build Coastguard Worker
295*523fa7a6SAndroid Build Coastguard Workerdef _is_mutable_buffer(
296*523fa7a6SAndroid Build Coastguard Worker    node: Node, graph_signature: Optional[ExportGraphSignature] = None
297*523fa7a6SAndroid Build Coastguard Worker) -> bool:
298*523fa7a6SAndroid Build Coastguard Worker    """
299*523fa7a6SAndroid Build Coastguard Worker    Check if the node is mutable buffer according to the provided graph signature.
300*523fa7a6SAndroid Build Coastguard Worker    """
301*523fa7a6SAndroid Build Coastguard Worker    # graph signature is None for memory planning passes not called from EdgeProgramManager, these paths are deprecated so mutable buffers are not supported on them.
302*523fa7a6SAndroid Build Coastguard Worker    if graph_signature is None:
303*523fa7a6SAndroid Build Coastguard Worker        return False
304*523fa7a6SAndroid Build Coastguard Worker    if node.op == "placeholder":
305*523fa7a6SAndroid Build Coastguard Worker        if isinstance(node.target, str):
306*523fa7a6SAndroid Build Coastguard Worker            if node.target in graph_signature.inputs_to_buffers:
307*523fa7a6SAndroid Build Coastguard Worker                fqn = graph_signature.inputs_to_buffers[node.target]
308*523fa7a6SAndroid Build Coastguard Worker                # if the buffer is mutated then record that
309*523fa7a6SAndroid Build Coastguard Worker                if fqn in graph_signature.buffers_to_mutate.values():
310*523fa7a6SAndroid Build Coastguard Worker                    return True
311*523fa7a6SAndroid Build Coastguard Worker    return False
312*523fa7a6SAndroid Build Coastguard Worker
313*523fa7a6SAndroid Build Coastguard Worker
def get_graph_input_tensors(
    nodes: Iterable[Node], graph_signature: Optional[ExportGraphSignature] = None
) -> Set[TensorSpec]:
    """
    Collect the TensorSpecs of all graph inputs: placeholder nodes that are
    not mutable buffers according to ``graph_signature``.
    """
    return {
        spec
        for node in nodes
        if node.op == "placeholder" and not _is_mutable_buffer(node, graph_signature)
        for spec in get_node_tensor_specs(node)
    }
324*523fa7a6SAndroid Build Coastguard Worker
325*523fa7a6SAndroid Build Coastguard Worker
def get_graph_output_tensors(nodes: Iterable[Node]) -> Set[TensorSpec]:
    """Collect the TensorSpecs attached to the graph's output node(s)."""
    return {
        spec
        for node in nodes
        if node.op == "output"
        for spec in get_node_tensor_specs(node)
    }
334*523fa7a6SAndroid Build Coastguard Worker
335*523fa7a6SAndroid Build Coastguard Worker
336*523fa7a6SAndroid Build Coastguard Workerdef collect_specs_from_nodes(  # noqa: C901
337*523fa7a6SAndroid Build Coastguard Worker    nodes: Iterable[Node],
338*523fa7a6SAndroid Build Coastguard Worker    graph_signature: Optional[ExportGraphSignature] = None,
339*523fa7a6SAndroid Build Coastguard Worker    ignore_graph_input: bool = False,
340*523fa7a6SAndroid Build Coastguard Worker    ignore_graph_output: bool = False,
341*523fa7a6SAndroid Build Coastguard Worker    ignore_const: bool = True,
342*523fa7a6SAndroid Build Coastguard Worker    ignore_out_var_node: bool = True,
343*523fa7a6SAndroid Build Coastguard Worker    dedup: bool = True,
344*523fa7a6SAndroid Build Coastguard Worker    do_assertion: bool = True,
345*523fa7a6SAndroid Build Coastguard Worker    ignore_dynamic_unbound_tensor: bool = True,
346*523fa7a6SAndroid Build Coastguard Worker) -> Iterable[TensorSpec]:
347*523fa7a6SAndroid Build Coastguard Worker    r"""
348*523fa7a6SAndroid Build Coastguard Worker    Collect specs from the passed in nodes. Do filtering as controlled by
349*523fa7a6SAndroid Build Coastguard Worker    arguments.
350*523fa7a6SAndroid Build Coastguard Worker    Arguments:
351*523fa7a6SAndroid Build Coastguard Worker        ignore_graph_input: ignore graph input tensors from placeholder nodes
352*523fa7a6SAndroid Build Coastguard Worker        ignore_const: whether to ignore the const
353*523fa7a6SAndroid Build Coastguard Worker        ignore_out_var_node: whether to ignore out variant node
354*523fa7a6SAndroid Build Coastguard Worker        dedup: whether do dedup
355*523fa7a6SAndroid Build Coastguard Worker        do_assertion: whether to assert the filtered nodes belong to a resticted set like alloc, getitem
356*523fa7a6SAndroid Build Coastguard Worker    """
357*523fa7a6SAndroid Build Coastguard Worker    unique_spec = set()
358*523fa7a6SAndroid Build Coastguard Worker    graph_input_tensors: Set[TensorSpec] = (
359*523fa7a6SAndroid Build Coastguard Worker        get_graph_input_tensors(nodes, graph_signature) if ignore_graph_input else set()
360*523fa7a6SAndroid Build Coastguard Worker    )
361*523fa7a6SAndroid Build Coastguard Worker    graph_output_tensors: Set[TensorSpec] = (
362*523fa7a6SAndroid Build Coastguard Worker        get_graph_output_tensors(nodes) if ignore_graph_output else set()
363*523fa7a6SAndroid Build Coastguard Worker    )
364*523fa7a6SAndroid Build Coastguard Worker
365*523fa7a6SAndroid Build Coastguard Worker    for node in nodes:
366*523fa7a6SAndroid Build Coastguard Worker        # ignore the specs from unrelevant Fx ops for now.
367*523fa7a6SAndroid Build Coastguard Worker        if node.op in ["get_attr"]:
368*523fa7a6SAndroid Build Coastguard Worker            continue
369*523fa7a6SAndroid Build Coastguard Worker
370*523fa7a6SAndroid Build Coastguard Worker        # don't reallocate memory for out-variant op's output tensors,
371*523fa7a6SAndroid Build Coastguard Worker        # since they are just input tenors.
372*523fa7a6SAndroid Build Coastguard Worker        if ignore_out_var_node and _is_out_var_node(node):
373*523fa7a6SAndroid Build Coastguard Worker            continue
374*523fa7a6SAndroid Build Coastguard Worker
375*523fa7a6SAndroid Build Coastguard Worker        if not (specs := get_node_tensor_specs(node)):
376*523fa7a6SAndroid Build Coastguard Worker            continue
377*523fa7a6SAndroid Build Coastguard Worker
378*523fa7a6SAndroid Build Coastguard Worker        if _is_inplace_node(node):
379*523fa7a6SAndroid Build Coastguard Worker            continue
380*523fa7a6SAndroid Build Coastguard Worker
381*523fa7a6SAndroid Build Coastguard Worker        if do_assertion:
382*523fa7a6SAndroid Build Coastguard Worker            internal_assert(
383*523fa7a6SAndroid Build Coastguard Worker                node.op in ("placeholder", "output")
384*523fa7a6SAndroid Build Coastguard Worker                or node.target
385*523fa7a6SAndroid Build Coastguard Worker                in [
386*523fa7a6SAndroid Build Coastguard Worker                    memory.alloc,
387*523fa7a6SAndroid Build Coastguard Worker                    memory.view,
388*523fa7a6SAndroid Build Coastguard Worker                    operator.getitem,
389*523fa7a6SAndroid Build Coastguard Worker                    torch.ops.higher_order.cond,
390*523fa7a6SAndroid Build Coastguard Worker                    exir_while,
391*523fa7a6SAndroid Build Coastguard Worker                    torch.ops.higher_order.map_impl,
392*523fa7a6SAndroid Build Coastguard Worker                    executorch_call_delegate,
393*523fa7a6SAndroid Build Coastguard Worker                ],
394*523fa7a6SAndroid Build Coastguard Worker                f"Unexpected op {node.op}, target {node.target}",
395*523fa7a6SAndroid Build Coastguard Worker            )
396*523fa7a6SAndroid Build Coastguard Worker        for spec in specs:
397*523fa7a6SAndroid Build Coastguard Worker            if spec is None:
398*523fa7a6SAndroid Build Coastguard Worker                continue
399*523fa7a6SAndroid Build Coastguard Worker            # Dynamic unbound tensors' memory will be allocated by the runtime.
400*523fa7a6SAndroid Build Coastguard Worker            # Memory planning should ignore them.
401*523fa7a6SAndroid Build Coastguard Worker            if (
402*523fa7a6SAndroid Build Coastguard Worker                ignore_dynamic_unbound_tensor
403*523fa7a6SAndroid Build Coastguard Worker                and spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
404*523fa7a6SAndroid Build Coastguard Worker            ):
405*523fa7a6SAndroid Build Coastguard Worker                continue
406*523fa7a6SAndroid Build Coastguard Worker
407*523fa7a6SAndroid Build Coastguard Worker            # Note: graph input may be the output of other ops (e.g. the return op)
408*523fa7a6SAndroid Build Coastguard Worker            # If ignore_graph_input is true, we should ignore those Tensor so
409*523fa7a6SAndroid Build Coastguard Worker            # we skip planning memory for graph input.
410*523fa7a6SAndroid Build Coastguard Worker            if ignore_graph_input and spec in graph_input_tensors:
411*523fa7a6SAndroid Build Coastguard Worker                continue
412*523fa7a6SAndroid Build Coastguard Worker            if ignore_graph_output and spec in graph_output_tensors:
413*523fa7a6SAndroid Build Coastguard Worker                continue
414*523fa7a6SAndroid Build Coastguard Worker            if (
415*523fa7a6SAndroid Build Coastguard Worker                ignore_const
416*523fa7a6SAndroid Build Coastguard Worker                and spec.const
417*523fa7a6SAndroid Build Coastguard Worker                and not node.meta.get("weight_has_gradient", False)
418*523fa7a6SAndroid Build Coastguard Worker            ):
419*523fa7a6SAndroid Build Coastguard Worker                continue
420*523fa7a6SAndroid Build Coastguard Worker            if dedup:
421*523fa7a6SAndroid Build Coastguard Worker                if spec in unique_spec:
422*523fa7a6SAndroid Build Coastguard Worker                    continue
423*523fa7a6SAndroid Build Coastguard Worker                else:
424*523fa7a6SAndroid Build Coastguard Worker                    unique_spec.add(spec)
425*523fa7a6SAndroid Build Coastguard Worker            yield spec
426*523fa7a6SAndroid Build Coastguard Worker
427*523fa7a6SAndroid Build Coastguard Worker
def update_all_tensors_lifetime(
    graph_module: torch.fx.GraphModule,
    graph_signature: Optional[ExportGraphSignature] = None,
) -> Set[TensorSpec]:
    r"""
    Record, for every tensor spec in the Fx graph, the node indices at which
    it is live.

    Walks the graph in order; at each node, the specs belonging to the node
    itself and to all of its arguments have their lifetime extended to cover
    this node's index.

    Returns the set of all TensorSpecs encountered.
    """
    seen_specs: Set[TensorSpec] = set()
    for idx, nd in enumerate(graph_module.graph.nodes):
        # The node plus everything it consumes (positional and keyword args).
        candidates = itertools.chain([nd], nd.args, nd.kwargs.values())
        for spec in collect_specs_from_nodes(
            filter_nodes(candidates),
            graph_signature,
            ignore_graph_input=False,
            ignore_const=False,
            ignore_out_var_node=False,
            dedup=False,
            do_assertion=False,
            ignore_dynamic_unbound_tensor=False,
        ):
            update_tensor_lifetime(spec, idx)
            seen_specs.add(spec)
    return seen_specs
450*523fa7a6SAndroid Build Coastguard Worker
451*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class SharedObject:
    r"""
    A segment of a memory buffer that multiple tensors with non-overlapping
    lifetimes can share.

    Availability is tracked through ``last_used_index``: the segment may be
    reused by any node whose index is strictly greater than
    ``last_used_index``.
    """

    # index of the shared object in the list of shared objects, used as a unique id
    idx: int
    # byte offset of this segment within its memory buffer; stays at its
    # initial placeholder value until materialize_buffer assigns a location
    offset: int
    # size of this shared object in bytes; grows to fit the largest tensor
    # assigned to it (see pick_shared_obj)
    size: int
    # the object will be available for index (last_used_index + 1)
    last_used_index: int
470*523fa7a6SAndroid Build Coastguard Worker
471*523fa7a6SAndroid Build Coastguard Worker
def materialize_buffer(
    shared_objects: List[SharedObject], input_total_size: int = 0
) -> int:
    r"""
    Lay the given SharedObjects out back-to-back, assigning each a concrete
    ``offset`` in the memory buffer.

    Offsets start at ``input_total_size`` (buffer space already reserved,
    e.g. by an outer module) and grow by each object's size.

    Assumes all the passed-in shared objects belong to the same memory
    buffer. Returns the total number of bytes consumed, i.e. the required
    buffer size.
    """
    cursor = input_total_size
    for obj in shared_objects:
        obj.offset = cursor
        cursor = cursor + obj.size
    return cursor
485*523fa7a6SAndroid Build Coastguard Worker
486*523fa7a6SAndroid Build Coastguard Worker
def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int:
    r"""
    Return the absolute difference between the shared object's size and the
    tensor's allocated memory — i.e. how closely they match.
    """
    delta = sobj.size - spec.allocated_memory
    return delta if delta >= 0 else -delta
493*523fa7a6SAndroid Build Coastguard Worker
494*523fa7a6SAndroid Build Coastguard Worker
def pick_shared_obj(
    shared_objects: List[SharedObject], spec: TensorSpec
) -> SharedObject:
    r"""
    Pick the available shared object whose size is closest to the tensor's
    allocated memory; if none is available, create a new one and append it
    to ``shared_objects``.

    A shared object is available iff its ``last_used_index`` is strictly
    before the tensor's first use, i.e. their lifetimes do not overlap.

    The chosen object's ``last_used_index`` is extended to the tensor's last
    use, and its ``size`` is grown to fit the tensor.
    """
    # TODO: do better than linear scan
    picked: Optional[SharedObject] = None
    for sobj in shared_objects:
        # Only objects whose lifetime ended before this tensor's first use
        # can be reused.
        if spec.lifetime[0] <= sobj.last_used_index:
            continue
        if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif(picked, spec):
            picked = sobj

    if picked is not None:
        # Mutate only the object we actually chose. (Previously every
        # intermediate best candidate was mutated as soon as it was found,
        # so losing candidates kept the extended last_used_index/size —
        # making them unavailable for later tensors and inflating their
        # sizes, which wasted planned memory.)
        picked.last_used_index = spec.lifetime[1]
        picked.size = max(picked.size, spec.allocated_memory)
    else:
        # Nothing is free during this tensor's lifetime: allocate a fresh
        # segment. The offset is a placeholder until materialize_buffer runs.
        picked = SharedObject(
            len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
        )
        shared_objects.append(picked)

    return picked
519*523fa7a6SAndroid Build Coastguard Worker
520*523fa7a6SAndroid Build Coastguard Worker
def get_node_tensor_specs(
    node: torch.fx.Node,
) -> Union[List[TensorSpec], Tuple[TensorSpec]]:
    r"""
    Return the list of TensorSpecs for ``node``, or an empty list if the
    node has none.

    Scalar entries (int/float/bool/str/None) that may appear alongside
    tensor specs in the node's ``spec`` metadata are filtered out.
    """
    if node.target == memory.view:
        # A memory.view aliases its base tensor's storage, so report the
        # base node's specs instead of the view's own.
        source = node.args[0]
        assert isinstance(source, torch.fx.Node)
        found = source.meta.get("spec")
    else:
        found = node.meta.get("spec")

    if isinstance(found, TensorSpec):
        found = [found]
    if not isinstance(found, (list, tuple)):
        return []
    scalar_types = (int, float, bool, str, type(None))
    return [entry for entry in found if not isinstance(entry, scalar_types)]
546*523fa7a6SAndroid Build Coastguard Worker
547*523fa7a6SAndroid Build Coastguard Worker
def greedy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:
    r"""
    Greedy memory planning: tensors whose lifetimes do not overlap share a
    segment (SharedObject) of their memory buffer.

    Writes mem_obj_id/mem_offset onto each planned TensorSpec and returns
    one total buffer size (in bytes) per mem_id.
    """
    spec_to_object: Dict[TensorSpec, SharedObject] = {}
    objects_per_buffer: Dict[int, List[SharedObject]] = defaultdict(list)
    # Skip the op-type assertion in collect_specs_from_nodes if earlier
    # passes already encountered (and ignored) to_out_variant failures.
    run_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False)

    # Assign each tensor the available shared object with the closest size;
    # pick_shared_obj creates a fresh one when none is available.
    for spec in collect_specs_from_nodes(
        graph_module.graph.nodes,
        graph_signature,
        do_assertion=run_assertion,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    ):
        if spec.mem_id is None:
            # default to the single memory layer with id 1
            spec.mem_id = 1
        spec.realign(alignment)
        spec_to_object[spec] = pick_shared_obj(objects_per_buffer[spec.mem_id], spec)

    if not objects_per_buffer:
        # Cannot find any tensor in the graph that needs to be allocated.
        # Return [0, 0] to be consistent with default behavior of naive.
        total_sizes = [0, 0]
    else:
        total_sizes = [0] * (max(objects_per_buffer.keys()) + 1)
        input_sizes = getattr(graph_module, "input_mem_buffer_sizes", None)
        for mem_id, objects in objects_per_buffer.items():
            # Start past any buffer space an outer module already reserved.
            reserved = 0
            if input_sizes:
                # pyre-fixme[6]: For 1st argument expected
                #  `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
                if len(input_sizes) > mem_id:
                    # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
                    reserved = input_sizes[mem_id]
            total_sizes[mem_id] = materialize_buffer(objects, reserved)

        # Every shared object now has a concrete size and offset; propagate
        # those back onto the tensor specs that use it.
        for spec, shared in spec_to_object.items():
            spec.mem_obj_id = shared.idx
            spec.mem_offset = shared.offset

    logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
    return total_sizes
602*523fa7a6SAndroid Build Coastguard Worker
603*523fa7a6SAndroid Build Coastguard Worker
def naive(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:
    r"""
    Naive memory planning: every tensor gets its own slice of the buffer,
    reserved bump-allocator style; storage is never shared or reused.

    Writes mem_offset onto each planned TensorSpec and returns one total
    buffer size (in bytes) per mem_id.
    """

    def _bump_alloc(sizes: List[int], mem_id: int, nbytes: int) -> int:
        """Reserve nbytes at the end of buffer mem_id; return its start offset."""
        while len(sizes) <= mem_id:
            sizes.append(0)
        start = sizes[mem_id]
        sizes[mem_id] = start + nbytes
        return start

    # Continue from any buffer space an outer module already reserved.
    planned = getattr(graph_module, "input_mem_buffer_sizes", None)
    if planned is None:
        planned = [0, 0]
    planned = typing.cast(List[int], planned)

    for spec in collect_specs_from_nodes(
        graph_module.graph.nodes,
        graph_signature,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    ):
        # assume a single memory layer which has mem_id 1
        if spec.mem_id is None:
            spec.mem_id = 1
        # reserve spec.allocated_memory bytes in the buffer with that mem_id
        spec.realign(alignment)
        spec.mem_offset = _bump_alloc(planned, spec.mem_id, spec.allocated_memory)

    logging.debug(f"naive algorithm returns bufsizes: {planned}")
    return planned
642*523fa7a6SAndroid Build Coastguard Worker
643*523fa7a6SAndroid Build Coastguard Worker
644*523fa7a6SAndroid Build Coastguard Workerdef get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
645*523fa7a6SAndroid Build Coastguard Worker    for nd in graph_module.graph.nodes:
646*523fa7a6SAndroid Build Coastguard Worker        if nd.target is torch.ops.higher_order.cond:
647*523fa7a6SAndroid Build Coastguard Worker            yield nd
648*523fa7a6SAndroid Build Coastguard Worker
649*523fa7a6SAndroid Build Coastguard Worker
def get_while_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
    """Return an iterator over every ``exir_while`` call node in the graph."""
    return (
        candidate
        for candidate in graph_module.graph.nodes
        if candidate.target is exir_while
    )
654*523fa7a6SAndroid Build Coastguard Worker
655*523fa7a6SAndroid Build Coastguard Worker
656*523fa7a6SAndroid Build Coastguard Workerdef get_map_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
657*523fa7a6SAndroid Build Coastguard Worker    for nd in graph_module.graph.nodes:
658*523fa7a6SAndroid Build Coastguard Worker        if nd.target is torch.ops.higher_order.map_impl:
659*523fa7a6SAndroid Build Coastguard Worker            yield nd
660*523fa7a6SAndroid Build Coastguard Worker
661*523fa7a6SAndroid Build Coastguard Worker
def get_return_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]:
    """Collect the TensorSpecs attached to the graph's final (output) node."""
    all_nodes = graph_module.graph.nodes
    if len(all_nodes) == 0:
        return set()
    output_node = next(iter(reversed(all_nodes)))
    leaves, _ = tree_flatten(output_node.meta["spec"])
    return set(leaves)
670*523fa7a6SAndroid Build Coastguard Worker
671*523fa7a6SAndroid Build Coastguard Worker
def get_input_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]:
    """Collect the TensorSpecs attached to the graph's placeholder nodes."""
    found = set()
    for nd in graph_module.graph.nodes:
        if nd.op != "placeholder":
            continue
        found.update(tree_flatten(nd.meta["spec"])[0])
    return found
680*523fa7a6SAndroid Build Coastguard Worker
681*523fa7a6SAndroid Build Coastguard Worker
def insert_calls_to_free(
    graph_module: fx.GraphModule, allspecs: Set[TensorSpec]
) -> None:
    """
    Insert calls to ``memory.free`` for dynamic unbound tensors that go out
    of lifetime.

    Only handles the module itself. Submodules are handled in separate calls
    of this function.

    NOTE: this method will invalidate the lifetimes recorded in TensorSpec
    because of the extra free nodes added to the graph. The graph module is
    mutated in place and recompiled at the end.
    """
    # Note: we should never free an output tensor
    return_specs = get_return_specs(graph_module)
    # Note: we should never free an input tensor, since the buffer for an
    # input tensor may be passed in from the user.
    input_specs = get_input_specs(graph_module)
    # Group the freeable dynamic-unbound specs by the node index at which
    # they die (end of lifetime).
    idx_to_dead_specs = defaultdict(list)
    for spec in allspecs:
        if (
            spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
            and spec not in return_specs
            and spec not in input_specs
        ):
            idx_to_dead_specs[spec.lifetime[1]].append(spec)

    num_nodes = len(graph_module.graph.nodes)
    # Iterate in reverse order so an inserted free node does not disturb the
    # numbering of the nodes still to be visited.
    for node, node_idx in zip(
        reversed(graph_module.graph.nodes), range(num_nodes - 1, -1, -1)
    ):
        dead_specs = idx_to_dead_specs.get(node_idx, [])
        if not dead_specs:
            continue
        # Free each dead spec immediately after its last use.
        with graph_module.graph.inserting_after(node):
            for spec in dead_specs:
                graph_module.graph.call_function(memory.free, (spec,))
    graph_module.recompile()
721*523fa7a6SAndroid Build Coastguard Worker
722*523fa7a6SAndroid Build Coastguard Worker
def apply_algo(
    algo: Callable[
        [torch.fx.GraphModule, int, Optional[ExportGraphSignature], bool, bool],
        List[int],
    ],
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:
    """
    Recursively apply ``algo`` (e.g. ``greedy`` or ``naive``) to graph_module
    and to the submodules of its control-flow nodes (cond/while/map).

    Returns the per-mem_id buffer sizes for the whole (sub)tree; the same
    list is also stored under ``"non_const_buffer_sizes"`` in each module's
    meta.

    Quite naive right now since it does not take the following optimizations
    into consideration:
    1. for a conditional structure, the true branch and the false branch do
       not overlap in lifetime and can share tensor storage
    2. tensors inside a submodule (e.g. true branch) have opportunities to
       share storage with tensors in the outer module.
    TODO: make these optimizations once we have some baseline working.
    """
    # Lifetimes must be recorded before the planning algorithm can run.
    specs = update_all_tensors_lifetime(graph_module, graph_signature)
    bufsizes: List[int] = algo(
        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
    )
    # Free dynamic unbound tensors once they die (invalidates lifetimes).
    insert_calls_to_free(graph_module, specs)

    def handle_submodule(
        submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
    ) -> None:
        # Plan memory for one control-flow submodule, growing the enclosing
        # bufsizes so outer and inner allocations do not collide.
        nonlocal bufsizes
        assert submodule_nd.op == "get_attr"
        submodule = getattr(graph_module, submodule_nd.target)
        # memory planning for submodule needs to be aware of the amount of
        # buffer already allocated.
        submodule.input_mem_buffer_sizes = bufsizes
        bufsizes = apply_algo(
            algo,
            submodule,
            alignment,
            graph_signature,
            alloc_graph_input=alloc_graph_input,
            alloc_graph_output=True,
        )
        submodule.meta.update({"non_const_buffer_sizes": bufsizes})

    for cond_node in get_cond_nodes(graph_module):
        # args[1] / args[2] hold the true / false branch submodules.
        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))

    for while_node in get_while_nodes(graph_module):
        # args[0] / args[1] hold the condition / body submodules.
        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
    # TODO: Add test coverage for map operator once dynamo tracing is
    # fully supported for this. T142287208
    for map_node in get_map_nodes(graph_module):
        handle_submodule(
            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
        )

    graph_module.meta.update({"non_const_buffer_sizes": bufsizes})

    return bufsizes
787