# xref: /aosp_15_r20/external/pytorch/test/test_deploy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: package/deploy"]
2
3import textwrap
4import types
5
6from torch.testing._internal.common_utils import run_tests, TestCase
7from torch.utils._freeze import Freezer, PATH_MARKER
8
9
10class TestFreezer(TestCase):
11    """Tests the freeze.py script"""
12
13    def test_compile_string(self):
14        freezer = Freezer(True)
15        code_str = textwrap.dedent(
16            """
17            class MyCls:
18                def __init__(self) -> None:
19                    pass
20            """
21        )
22        co = freezer.compile_string(code_str)
23        num_co = 0
24
25        def verify_filename(co: types.CodeType):
26            nonlocal num_co
27
28            if not isinstance(co, types.CodeType):
29                return
30
31            self.assertEqual(PATH_MARKER, co.co_filename)
32            num_co += 1
33
34            for nested_co in co.co_consts:
35                verify_filename(nested_co)
36
37        verify_filename(co)
38        # there is at least one nested code object besides the top level one
39        self.assertTrue(num_co >= 2)
40
41
if __name__ == "__main__":
    # When executed directly (not imported), hand off to the PyTorch test
    # runner imported from torch.testing._internal.common_utils.
    run_tests()
44