xref: /aosp_15_r20/external/pytorch/test/test_deploy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: package/deploy"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport textwrap
4*da0073e9SAndroid Build Coastguard Workerimport types
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase
7*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._freeze import Freezer, PATH_MARKER
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerclass TestFreezer(TestCase):
11*da0073e9SAndroid Build Coastguard Worker    """Tests the freeze.py script"""
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker    def test_compile_string(self):
14*da0073e9SAndroid Build Coastguard Worker        freezer = Freezer(True)
15*da0073e9SAndroid Build Coastguard Worker        code_str = textwrap.dedent(
16*da0073e9SAndroid Build Coastguard Worker            """
17*da0073e9SAndroid Build Coastguard Worker            class MyCls:
18*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
19*da0073e9SAndroid Build Coastguard Worker                    pass
20*da0073e9SAndroid Build Coastguard Worker            """
21*da0073e9SAndroid Build Coastguard Worker        )
22*da0073e9SAndroid Build Coastguard Worker        co = freezer.compile_string(code_str)
23*da0073e9SAndroid Build Coastguard Worker        num_co = 0
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker        def verify_filename(co: types.CodeType):
26*da0073e9SAndroid Build Coastguard Worker            nonlocal num_co
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker            if not isinstance(co, types.CodeType):
29*da0073e9SAndroid Build Coastguard Worker                return
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(PATH_MARKER, co.co_filename)
32*da0073e9SAndroid Build Coastguard Worker            num_co += 1
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker            for nested_co in co.co_consts:
35*da0073e9SAndroid Build Coastguard Worker                verify_filename(nested_co)
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker        verify_filename(co)
38*da0073e9SAndroid Build Coastguard Worker        # there is at least one nested code object besides the top level one
39*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(num_co >= 2)
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # When executed directly as a script, run this module's tests through
    # the PyTorch test runner imported from common_utils.
    run_tests()
44