xref: /aosp_15_r20/external/pytorch/test/jit/test_modules.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5
6import torch
7from torch.testing._internal.jit_utils import JitTestCase
8
9
# Put the test/ directory on the import path so the shared helper files
# there can be imported. It is the parent of the directory that contains
# this file's resolved (symlink-free) path.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    # This module is not an entry point; the message below directs the
    # user to run it through test/test_jit.py instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
20
21
22class TestModules(JitTestCase):
23    def test_script_module_with_constants_list(self):
24        """
25        Test that a module that has __constants__ set to something
26        that is not a set can be scripted.
27        """
28
29        # torch.nn.Linear has a __constants__ attribute defined
30        # and intialized to a list.
31        class Net(torch.nn.Linear):
32            x: torch.jit.Final[int]
33
34            def __init__(self) -> None:
35                super().__init__(5, 10)
36                self.x = 0
37
38        self.checkModule(Net(), (torch.randn(5),))
39