xref: /aosp_15_r20/external/executorch/exir/program/test/test_fake_program.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
8
9import sys
10import unittest
11
12import torch
13
14from executorch.exir.program._fake_program import (
15    get_fake_program,
16    update_to_real_program,
17)
18from torch.export import export, ExportedProgram
19
20
21def get_exported_program() -> ExportedProgram:
22    class Linear(torch.nn.Module):
23        def __init__(self):
24            super().__init__()
25            self.linear = torch.nn.Linear(10, 10)
26            self.register_buffer("buf", torch.randn(10, 10), persistent=False)
27
28        def forward(self, arg) -> torch.Tensor:
29            return self.linear(arg) + self.buf
30
31    linear = Linear()
32    exported_program = export(
33        linear,
34        args=(torch.randn(10, 10),),
35    ).run_decompositions()
36    return exported_program
37
38
class TestFakeProgram(unittest.TestCase):
    """Tests for ``get_fake_program`` and ``update_to_real_program``."""

    def test_fake_program(self) -> None:
        exported_program = get_exported_program()
        fake_program = get_fake_program(exported_program)

        # Fake program deep copies attributes besides verifier, state_dict and
        # constants: the copies compare equal but are distinct objects.
        self.assertEqual(exported_program.graph_signature, fake_program.graph_signature)
        self.assertIsNot(exported_program.graph_signature, fake_program.graph_signature)
        self.assertEqual(
            exported_program.module_call_graph, fake_program.module_call_graph
        )
        self.assertIsNot(
            exported_program.module_call_graph, fake_program.module_call_graph
        )

        # Verifier is static, so the very same object is shared.
        self.assertIs(exported_program.verifier, fake_program.verifier)

        # Fake program uses fake tensors for the state dict: same keys and the
        # same tensor metadata, but never the real tensor objects themselves.
        self.assertEqual(
            exported_program.state_dict.keys(), fake_program.state_dict.keys()
        )
        for name, real_tensor in exported_program.state_dict.items():
            fake_tensor = fake_program.state_dict[name]
            self.assertIsNot(real_tensor, fake_tensor, name)
            self.assertEqual(real_tensor.shape, fake_tensor.shape, name)
            self.assertEqual(real_tensor.dtype, fake_tensor.dtype, name)

        # sys.getsizeof is shallow (it measures only the dict container, not
        # the tensors it references); this is just a sanity check that the
        # fake state dict container is not larger than the real one.
        self.assertLessEqual(
            sys.getsizeof(fake_program.state_dict),
            sys.getsizeof(exported_program.state_dict),
        )

        # Do not copy constants: the fake program shares the real dict.
        self.assertIs(exported_program.constants, fake_program.constants)

        # After the update, the fake program must point at the real tensors.
        # Compare per entry by identity: `==` on a dict of multi-element
        # tensors only works through Python's identity short-circuit and
        # would raise an ambiguous-truth-value error for distinct tensors.
        update_to_real_program(fake_program, exported_program)
        self.assertEqual(
            exported_program.state_dict.keys(), fake_program.state_dict.keys()
        )
        for name, real_tensor in exported_program.state_dict.items():
            self.assertIs(fake_program.state_dict[name], real_tensor, name)
80