# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
import unittest
from typing import Any, Callable, List, Optional, Tuple, Type

import executorch.exir as exir

import torch
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.memory_planning import (
    filter_nodes,
    get_node_tensor_specs,
    greedy,
    naive,
    Verifier,
)
from executorch.exir.pass_base import PassResult
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes import (  # noqa
    MemoryPlanningPass,
    SpecPropPass,
    ToOutVarPass,
)
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from parameterized import parameterized

from torch import nn
from torch.ao.quantization import (  # @manual=//caffe2:torch
    float_qparams_weight_only_qconfig,
)
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)
from torch.ao.quantization.observer import (
    default_dynamic_quant_observer,
    default_per_channel_weight_observer,
)
from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.export import export
from torch.export.exported_program import ExportGraphSignature
from torch.fx import Graph, GraphModule, Node
from torch.nn import functional as F

torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib")


def swap_modules(
    module: torch.nn.Module,
    condition: Callable[[torch.nn.Module], bool],
    convert_func: Callable[[torch.nn.Module], torch.nn.Module],
) -> None:
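    """Recursively replace submodules in-place: any child module for which
    ``condition`` holds is swapped for ``convert_func(child)``. Children are
    visited before their parents, so nested matches are converted too."""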
    reassign = {}
    for name, mod in module.named_children():
        swap_modules(mod, condition, convert_func)
        if condition(mod):
            out = convert_func(mod)
            reassign[name] = out
    for key, value in reassign.items():
        module._modules[key] = value


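# Each loop iteration below creates two fresh intermediates (a mul output and
# an add output) with short lifetimes, which gives the memory planner plenty
# of opportunities to reuse storage.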
class ToyModelForMemPlanning(torch.nn.Module):
    def __init__(self) -> None:
        super(ToyModelForMemPlanning, self).__init__()

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        o = a
        for _ in range(10):
            o = o * a
            o = o + b
        return o

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(10), torch.randn(10))


class ModelWithDifferentTensorSizes(torch.nn.Module):
    def __init__(self) -> None:
        super(ModelWithDifferentTensorSizes, self).__init__()
        self.linears = torch.nn.ModuleList()
        for x in [2, 4, 8, 16, 32, 64, 128]:
            self.linears.append(torch.nn.Linear(x, x * 2))

    def forward(self, i: torch.Tensor) -> torch.Tensor:
        o1 = i
        for linear in self.linears:
            o1 = linear(o1)
        o2 = i
        for linear in self.linears:
            o2 = linear(o2)
        return o1 + o2

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(2),)


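# Returns two tensors, so the planner has to handle multiple graph outputs.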
class ModuleReturnTwo(nn.Module):
    def __init__(self) -> None:
        super(ModuleReturnTwo, self).__init__()
        self.linear1 = nn.Linear(8, 8)
        self.linear2 = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        o1 = self.linear1(x)
        o2 = self.linear2(x)
        return o1, o2

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(8),)


class ModuleListArg(nn.Module):
    r"""
    The module splits a tensor and concatenates the parts again. The cat op
    receives a list of tensors as its argument. We want to make sure we handle
    the lifetimes of tensors embedded inside a list argument correctly.
    """

    def __init__(self) -> None:
        super(ModuleListArg, self).__init__()

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        s0, s1 = torch.tensor_split(a, 2)
        s = torch.cat([s0, s1], 0)
        return s

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(8),)

    @staticmethod
    def extra_check(
        testcase: unittest.TestCase, graph_module: torch.fx.GraphModule
    ) -> None:
        """
        Make sure the getitem nodes stay alive at least until the cat node
        becomes alive, since the cat node has a list argument containing all
        the getitem nodes.
        """
        getitem_specs = []
        cat_specs = []
        for node in graph_module.graph.nodes:
            if node.target == torch.ops.aten.cat.out:
                cat_specs.append(node.meta["spec"])
            elif node.target == torch.ops.aten.slice_copy.Tensor_out:
                getitem_specs.append(node.meta["spec"])

        testcase.assertEqual(1, len(cat_specs))
        testcase.assertEqual(2, len(getitem_specs))
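        # spec.lifetime is a (start, end) pair of node indices; each getitem
        # must stay alive at least until the cat node becomes alive.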
        for getitem_spec in getitem_specs:
            testcase.assertTrue(getitem_spec.lifetime[1] >= cat_specs[0].lifetime[0])


class CustomPoolMemoryPlanningPass(MemoryPlanningPass):
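    """Pre-assigns a mem_id to the spec of each relevant node, then delegates
    to the base MemoryPlanningPass, which computes the offsets within each
    memory pool."""
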
    def call(self, graph_module: GraphModule) -> PassResult:
        for subgm in graph_module.modules():
            if not isinstance(subgm, GraphModule):
                continue
            for node in subgm.graph.nodes:
                # mem_id = 1 for placeholders and outputs of mul
                # mem_id = 3 for outputs of add
                # the parent class will copy the specs over to the alloc nodes
                if node.op == "placeholder":
                    node.meta["spec"].mem_id = 1
                    continue

                if node.op != "call_function":
                    continue

                if node.target == torch.ops.aten.add.out:
                    node.meta["spec"].mem_id = 3
                elif node.target == torch.ops.aten.mul.out:
                    node.meta["spec"].mem_id = 1

        return super().run(graph_module)

    def run(
        self,
        graph_module: torch.fx.GraphModule,
        graph_signature: Optional[ExportGraphSignature] = None,
    ) -> PassResult:
        return self.call(graph_module)


class MultiplePoolsToyModel(torch.nn.Module):
    def forward(self, a: torch.Tensor) -> torch.Tensor:
        # a: mem_id = 1, offset = 0
        # b: mem_id = 3, offset = 0
        # c: mem_id = 1, offset = 4
        # d: mem_id = 3, offset = 4
        # greedy:
        # e: mem_id = 1, offset = 0
        # naive:
        # e: mem_id = 1, offset = 8
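        # greedy can place e at offset 0 of pool 1 because a (whose last use
        # is computing c) is already dead when e is produced; naive never
        # reuses storage, so e lands after c at offset 8.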
        b = a + a
        c = a * b
        d = c + b
        e = c * d
        return e


def maketest(
    module_cls: Type[torch.nn.Module],
    criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
    extra_check: Optional[Callable[..., None]] = None,
    use_functionalization: bool = True,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
    has_unused_graph_input: bool = False,
) -> Callable[..., None]:
    # parameterized.expand is not compatible with maketest, so we just loop
    # through the test setups in the wrapper.
    def wrapper(self: "TestMemoryPlanning") -> None:
        nonlocal criteria
        if not criteria:
            criteria = [
                # naive algorithm does not reuse tensor storages
                (naive, False),
                # greedy algorithm should reuse tensor storages in the testing model
                (greedy, True),
            ]

        for algo, expect_reuse in criteria:
            print(
                f"algo {getattr(algo, '__name__', repr(algo))}, expect_reuse {expect_reuse}"
            )
            eager_module = module_cls().eval()
            # pyre-fixme[29]: `Union[nn.modules.module.Module,
            #  torch._tensor.Tensor]` is not a function.
            inputs = eager_module.get_random_inputs()
            graph_module = (
                to_edge(
                    export(
                        eager_module,
                        inputs,
                    )
                )
                .exported_program()
                .graph_module
            )

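            # Standard lowering pipeline: SpecPropPass attaches a TensorSpec
            # to each node, ToOutVarPass rewrites functional ops into their
            # out-variants, and MemoryPlanningPass assigns a mem_id and
            # mem_offset to every spec.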
            graph_module = PassManager(
                passes=[
                    SpecPropPass(),
                    ToOutVarPass(),
                    MemoryPlanningPass(
                        algo,
                        alloc_graph_input=alloc_graph_input,
                        alloc_graph_output=alloc_graph_output,
                    ),
                ],
            )(graph_module).graph_module

            self.verify_reuse(
                graph_module, expect_reuse, alloc_graph_input, alloc_graph_output
            )
            self.verify_graph_input_output(
                graph_module, alloc_graph_input, alloc_graph_output
            )

            self.verify_overlap_placeholders(has_unused_graph_input, graph_module)

            # print(f"Final code: {graph_module.code}")
            # print(f"Final graph: {graph_module.graph}")

            if extra_check:
                extra_check(self, graph_module)

    return wrapper


class TestMemoryPlanning(unittest.TestCase):
    def verify_reuse(
        self,
        graph_module: torch.fx.GraphModule,
        expect_reuse: bool,
        alloc_graph_input: bool,
        alloc_graph_output: bool,
    ) -> None:
        r"""
        Do a sanity check and verify tensor storage reuse.

        There should NOT be any storage overlap between tensors that have
        overlapping lifetimes.

        expect_reuse is True if we expect the algorithm to reuse tensor
        storage for at least one pair of tensors in the current test setup.
        """
        # this method throws if two tensors overlap in both lifetime and storage.
        num_reuse_pairs = Verifier(
            graph_module,
            alloc_graph_input=alloc_graph_input,
            alloc_graph_output=alloc_graph_output,
        ).verify_storage_reuse()

        print(f"num_reuse_pairs is {num_reuse_pairs}")
        if expect_reuse:
            self.assertTrue(num_reuse_pairs > 0)
        else:
            self.assertTrue(num_reuse_pairs == 0)

    def verify_graph_input_output(
        self,
        graph_module: torch.fx.GraphModule,
        alloc_graph_input: bool,
        alloc_graph_output: bool,
    ) -> None:
        Verifier(
            graph_module, alloc_graph_input, alloc_graph_output
        ).verify_graph_input_output()

    def verify_overlap_placeholders(
        self, has_unused_graph_input: bool, graph_module: GraphModule
    ) -> None:
        """
        If every placeholder node is used somewhere, then every pair of
        placeholders should have overlapping lifetimes.
        """
        if has_unused_graph_input:
            return

        ph_list = []
        for nd in graph_module.graph.nodes:
            if nd.op == "placeholder":
                ph_list.append(nd)

        # Since all placeholders are used somewhere, their lifetimes should
        # overlap.
        for i in range(len(ph_list)):
            for j in range(i + 1, len(ph_list)):
                ph_lhs = ph_list[i]
                ph_rhs = ph_list[j]
                self.assertTrue(
                    Verifier.lifetime_overlap(ph_lhs.meta["spec"], ph_rhs.meta["spec"])
                )

    test_basic: Callable[..., None] = maketest(ToyModelForMemPlanning)
    # TODO(zhxchen17) re-enable this.
    # test_while: Callable[..., None] = maketest(
    #     ModuleWhile,
    #     criteria=[
    #         ("naive", False),
    #         ("greedy", False),
    #     ],
    # )
    test_different_tensor_sizes: Callable[..., None] = maketest(
        ModelWithDifferentTensorSizes
    )

    test_return_two: Callable[..., None] = maketest(
        ModuleReturnTwo,
        criteria=[
            (naive, False),
            (greedy, True),
        ],
    )

    # greedy algorithm will reuse memory if we let the algorithm allocate
    # memory for both graph input and output.
    test_list_arg: Callable[..., None] = maketest(
        ModuleListArg,
        criteria=[
            (naive, False),
            (greedy, True),
        ],
        extra_check=ModuleListArg.extra_check,
    )

    def test_graph_input_output(self) -> None:
        for alloc_graph_input, alloc_graph_output in itertools.product(
            [True, False], [True, False]
        ):
            case = maketest(
                ModelWithDifferentTensorSizes,
                alloc_graph_input=alloc_graph_input,
                alloc_graph_output=alloc_graph_output,
            )
            case(self)


class TestVerifier(unittest.TestCase):
    def test_overlap(self) -> None:
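        # has_overlap checks whether two [start, end] intervals intersect;
        # the intervals here are lifetimes, so endpoints are inclusive.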
        # first encloses second
        self.assertTrue(Verifier.has_overlap([1, 10], [2, 3]))
        # second encloses first
        self.assertTrue(Verifier.has_overlap([2, 3], [1, 10]))
        # first on the left side
        self.assertTrue(Verifier.has_overlap([1, 4], [2, 5]))
        # first on the right side
        self.assertTrue(Verifier.has_overlap([2, 5], [1, 4]))

        # no overlap; first on the left side
        self.assertFalse(Verifier.has_overlap([1, 2], [5, 6]))
        # no overlap; first on the right side
        self.assertFalse(Verifier.has_overlap([5, 6], [1, 2]))


class TestMisc(unittest.TestCase):
    def test_filter_nodes(self) -> None:
        g = Graph()
        nd_pool = [
            Node(g, f"n{idx}", "placeholder", f"n{idx}", (), {}) for idx in range(10)
        ]
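        # filter_nodes flattens an arbitrarily nested argument structure
        # (bare nodes, tuples, lists, dicts) and skips None entries, yielding
        # only the Node objects.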
        actual_list = list(
            filter_nodes(
                [
                    nd_pool[0],
                    (nd_pool[1], nd_pool[2]),
                    None,
                    [nd_pool[3]],
                    {"first": nd_pool[4]},
                ]
            )
        )
        expected_list = nd_pool[:5]
        self.assertEqual(len(actual_list), len(expected_list))
        for act, exp in zip(actual_list, expected_list):
            self.assertEqual(id(act), id(exp))

    def quantize(self, eager_model: nn.Module) -> nn.Module:
        quantized_model = eager_model
        linear_qconfig_mapping = QConfigMapping().set_object_type(
            F.linear,
            QConfig(
                activation=default_dynamic_quant_observer,
                weight=default_per_channel_weight_observer,
            ),
        )
        embedding_qconfig_mapping = QConfigMapping().set_object_type(
            F.embedding,
            float_qparams_weight_only_qconfig,
        )
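        # Linear layers use dynamic activation observers with per-channel
        # weight observers; embeddings use float-qparams weight-only
        # quantization. Both use the ExecuTorch backend config.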
        # quantize module
        swap_modules(
            quantized_model,
            lambda mod: isinstance(mod, torch.nn.Linear),
            lambda mod: _convert_to_reference_decomposed_fx(
                prepare_fx(
                    mod,
                    linear_qconfig_mapping,
                    (torch.rand(1, mod.in_features),),
                    backend_config=get_executorch_backend_config(),
                ),
                backend_config=get_executorch_backend_config(),
            ),
        )
        swap_modules(
            quantized_model,
            lambda mod: isinstance(mod, torch.nn.Embedding),
            lambda mod: _convert_to_reference_decomposed_fx(
                prepare_fx(
                    mod,
                    embedding_qconfig_mapping,
                    (torch.ones(1, 1),),
                    backend_config=get_executorch_backend_config(),
                ),
                backend_config=get_executorch_backend_config(),
            ),
        )
        return quantized_model

    # pyre-ignore
    @parameterized.expand(
        [
            (
                naive,
                [(1, 0), (3, 0), (1, 4), (3, 4), (1, 8)],
                [0, 12, 0, 8],
            ),
            (
                greedy,
                [(1, 0), (3, 0), (1, 4), (3, 4), (1, 0)],
                [0, 8, 0, 8],
            ),
        ]
    )
    def test_multiple_pools(
        self,
        algo: Callable[..., List[int]],
        expected_allocs: List[Tuple[int, int]],
        expected_bufsizes: List[int],
    ) -> None:
        edge_program = to_edge(
            export(
                MultiplePoolsToyModel(),
                (torch.ones(1),),
            )
        )

        edge_program.to_executorch(
            exir.ExecutorchBackendConfig(
                memory_planning_pass=CustomPoolMemoryPlanningPass(
                    memory_planning_algo=algo,
                    alignment=1,
                ),
            )
        )
        graph_module = edge_program.exported_program().graph_module

        verifier = Verifier(
            graph_module,
            alloc_graph_input=True,
            alloc_graph_output=True,
        )
        verifier.verify_storage_reuse()
        verifier.verify_graph_input_output()

        idx = 0
        for node in graph_module.graph.nodes:
            if node.op == "placeholder" or (
                node.op == "call_function"
                and node.target in (torch.ops.aten.add.out, torch.ops.aten.mul.out)
            ):
                mem_id, mem_offset = expected_allocs[idx]
                self.assertEqual(node.meta["spec"].mem_id, mem_id)
                self.assertEqual(node.meta["spec"].mem_offset, mem_offset)
                idx += 1
        self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes)

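    # Constants (the linear weight/bias, the registered buffer, and the
    # promoted scalar) are stored with the program's constant data rather
    # than in a planned memory pool, so their specs should never get a mem_id.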
    def test_constants_not_memory_planned(self) -> None:
        class Simple(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.register_buffer("constant", torch.ones(5, 5))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear(x) + self.constant + 1)

        def count_planned_inputs(
            nodes: List[Node], graph_signature: Any  # pyre-ignore
        ) -> Tuple[int, int]:
            num_mem_planned_placeholders = 0
            num_placeholders = 0
            for node in nodes:
                if node.op == "placeholder":
                    num_placeholders += 1
                    specs = get_node_tensor_specs(node)
                    self.assertGreaterEqual(len(specs), 1)
                    for spec in specs:
                        if spec.mem_id is not None:
                            num_mem_planned_placeholders += 1
            return num_placeholders, num_mem_planned_placeholders

        model = Simple()
        inputs = (torch.randn(5, 5),)

        ep_no_input_planning = to_edge(export(model, inputs)).to_executorch(
            config=ExecutorchBackendConfig(
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )

        num_placeholders, num_planned_placeholders = count_planned_inputs(
            ep_no_input_planning.exported_program().graph_module.graph.nodes,
            ep_no_input_planning.exported_program().graph_signature,
        )
        self.assertEqual(
            num_planned_placeholders,
            0,
        )  # one unplanned user input and 4 constants that shouldn't be planned
        self.assertEqual(
            num_placeholders,
            5,  # x, self.constant, linear weight, linear bias, '1' scalar promoted to tensor
        )

        ep_input_planning = to_edge(export(model, inputs)).to_executorch(
            config=ExecutorchBackendConfig(
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True),
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )

        num_placeholders, num_planned_placeholders = count_planned_inputs(
            ep_input_planning.exported_program().graph_module.graph.nodes,
            ep_input_planning.exported_program().graph_signature,
        )
        self.assertEqual(
            num_planned_placeholders,
            1,
        )  # one planned user input and 4 constants that shouldn't be planned
        self.assertEqual(
            num_placeholders,
            5,
        )