xref: /aosp_15_r20/external/pytorch/test/jit/test_modules.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
# Put the top-level test/ directory (the parent of this file's directory) on
# sys.path so that helper modules living there can be imported.
pytorch_test_dir = os.path.abspath(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), os.path.pardir)
)
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    # Direct execution is refused; the message points users at
    # test/test_jit.py as the entry point for running these tests.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:"
        "\n\n\tpython test/test_jit.py TESTNAME\n\ninstead."
    )
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Workerclass TestModules(JitTestCase):
23*da0073e9SAndroid Build Coastguard Worker    def test_script_module_with_constants_list(self):
24*da0073e9SAndroid Build Coastguard Worker        """
25*da0073e9SAndroid Build Coastguard Worker        Test that a module that has __constants__ set to something
26*da0073e9SAndroid Build Coastguard Worker        that is not a set can be scripted.
27*da0073e9SAndroid Build Coastguard Worker        """
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker        # torch.nn.Linear has a __constants__ attribute defined
30*da0073e9SAndroid Build Coastguard Worker        # and intialized to a list.
31*da0073e9SAndroid Build Coastguard Worker        class Net(torch.nn.Linear):
32*da0073e9SAndroid Build Coastguard Worker            x: torch.jit.Final[int]
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
35*da0073e9SAndroid Build Coastguard Worker                super().__init__(5, 10)
36*da0073e9SAndroid Build Coastguard Worker                self.x = 0
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker        self.checkModule(Net(), (torch.randn(5),))
39