xref: /aosp_15_r20/external/pytorch/test/jit/test_unsupported_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5import unittest
6
7import torch
8
9
# Make the helper files in test/ importable. This module lives in test/jit/,
# so two dirname() calls on its resolved path yield the test/ directory, which
# is appended to sys.path so helper modules stored there can be imported.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
13from torch.testing._internal.jit_utils import JitTestCase
14
15
# Running this module as a script is unsupported: its tests are meant to be
# selected and run through test/test_jit.py, so direct execution fails loudly.
if __name__ == "__main__":
    _run_directly_msg = "".join(
        (
            "This test file is not meant to be run directly, use:\n\n",
            "\tpython test/test_jit.py TESTNAME\n\n",
            "instead.",
        )
    )
    raise RuntimeError(_run_directly_msg)
22
# NOTE: FIXING FAILING TESTS
# If you are seeing a test failure from this file, congrats: you have improved
# parity between the JIT and the Python API. Before you fix the test, you must
# also update the section of the documentation that describes the unsupported
# behavior; see `jit_unsupported.rst`.
28
29
class TestUnsupportedOps(JitTestCase):
    """Pin down eager-mode PyTorch APIs that ``torch.jit.script`` rejects.

    Each test defines small functions that call the APIs under test and
    asserts that scripting them raises. A failure here means TorchScript
    parity improved; the corresponding entry in ``jit_unsupported.rst`` must
    be updated together with the test.
    """

    def test_factory_ops_requires_grad_fail(self):
        # "Keyword argument {name} unknown" is a JIT-only error message, so
        # these functions succeed in eager mode and fail only under JIT.
        #
        # The complete issue and set of affected ops is tracked in
        # https://github.com/pytorch/pytorch/issues/30761. Only a few ops are
        # tested because they are expected to be fixed all at once.
        #
        # Each case is spelled out separately so that the expected highlight
        # passed to assertRaisesRegexWithHighlight matches the literal
        # ``torch.<op>`` call in that function's source.
        def ones():
            return torch.ones([2], requires_grad=True)

        with self.assertRaisesRegexWithHighlight(
            Exception, "Keyword argument requires_grad unknown", "torch.ones"
        ):
            torch.jit.script(ones)

        def randn():
            return torch.randn([2], requires_grad=True)

        with self.assertRaisesRegexWithHighlight(
            Exception, "Keyword argument requires_grad unknown", "torch.randn"
        ):
            torch.jit.script(randn)

        def zeros():
            return torch.zeros([2], requires_grad=True)

        with self.assertRaisesRegexWithHighlight(
            Exception, "Keyword argument requires_grad unknown", "torch.zeros"
        ):
            torch.jit.script(zeros)

    # NOTE(review): the LAPACK requirement presumably comes from orthogonal_
    # (QR-based) -- TODO confirm which helper actually needs it.
    @unittest.skipIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")
    def test_init_ops(self):
        # Each helper wraps one torch.nn.init function; all of them run in
        # eager mode, and scripting each one is expected to fail.
        def calculate_gain():
            return torch.nn.init.calculate_gain("leaky_relu", 0.2)

        def eye_():
            return torch.nn.init.eye_(torch.zeros([2, 2]))

        def dirac_():
            return torch.nn.init.dirac_(torch.empty(3, 16, 5, 5))

        # Named after the init function it actually exercises.
        def kaiming_normal_():
            return torch.nn.init.kaiming_normal_(torch.empty(3, 5))

        def orthogonal_():
            return torch.nn.init.orthogonal_(torch.empty(3, 5))

        def sparse():
            return torch.nn.init.sparse_(torch.empty(3, 5), sparsity=0.1)

        for func in [
            calculate_gain,
            eye_,
            dirac_,
            kaiming_normal_,
            orthogonal_,
            sparse,
        ]:
            # doesn't error in eager
            func()
            # The empty pattern matches any message, so any Exception raised
            # while scripting satisfies the check. ``msg`` names the helper
            # if it unexpectedly compiles, since the loop otherwise hides it.
            with self.assertRaisesRegex(
                Exception, "", msg=f"{func.__name__} unexpectedly scripted"
            ):
                torch.jit.script(func)
93