xref: /aosp_15_r20/external/pytorch/test/jit/test_exception.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Workerfrom torch import nn
4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerr"""
8*da0073e9SAndroid Build Coastguard WorkerTest TorchScript exception handling.
9*da0073e9SAndroid Build Coastguard Worker"""
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerclass TestException(TestCase):
13*da0073e9SAndroid Build Coastguard Worker    def test_pyop_exception_message(self):
14*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.jit.ScriptModule):
15*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
16*da0073e9SAndroid Build Coastguard Worker                super().__init__()
17*da0073e9SAndroid Build Coastguard Worker                self.conv = nn.Conv2d(1, 10, kernel_size=5)
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
20*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
21*da0073e9SAndroid Build Coastguard Worker                return self.conv(x)
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker        foo = Foo()
24*da0073e9SAndroid Build Coastguard Worker        # testing that the correct error message propagates
25*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
26*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"
27*da0073e9SAndroid Build Coastguard Worker        ):
28*da0073e9SAndroid Build Coastguard Worker            foo(torch.ones([123]))  # wrong size
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker    def test_builtin_error_messsage(self):
31*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
34*da0073e9SAndroid Build Coastguard Worker            def close_match(x):
35*da0073e9SAndroid Build Coastguard Worker                return x.masked_fill(True)
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
38*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
39*da0073e9SAndroid Build Coastguard Worker            "This op may not exist or may not be currently " "supported in TorchScript",
40*da0073e9SAndroid Build Coastguard Worker        ):
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
43*da0073e9SAndroid Build Coastguard Worker            def unknown_op(x):
44*da0073e9SAndroid Build Coastguard Worker                torch.set_anomaly_enabled(True)
45*da0073e9SAndroid Build Coastguard Worker                return x
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker    def test_exceptions(self):
48*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit(
49*da0073e9SAndroid Build Coastguard Worker            """
50*da0073e9SAndroid Build Coastguard Worker            def foo(cond):
51*da0073e9SAndroid Build Coastguard Worker                if bool(cond):
52*da0073e9SAndroid Build Coastguard Worker                    raise ValueError(3)
53*da0073e9SAndroid Build Coastguard Worker                return 1
54*da0073e9SAndroid Build Coastguard Worker        """
55*da0073e9SAndroid Build Coastguard Worker        )
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker        cu.foo(torch.tensor(0))
58*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.Error, "3"):
59*da0073e9SAndroid Build Coastguard Worker            cu.foo(torch.tensor(1))
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker        def foo(cond):
62*da0073e9SAndroid Build Coastguard Worker            a = 3
63*da0073e9SAndroid Build Coastguard Worker            if bool(cond):
64*da0073e9SAndroid Build Coastguard Worker                raise ArbitraryError(a, "hi")  # noqa: F821
65*da0073e9SAndroid Build Coastguard Worker                if 1 == 2:
66*da0073e9SAndroid Build Coastguard Worker                    raise ArbitraryError  # noqa: F821
67*da0073e9SAndroid Build Coastguard Worker            return a
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
70*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(foo)
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker        def exception_as_value():
73*da0073e9SAndroid Build Coastguard Worker            a = Exception()
74*da0073e9SAndroid Build Coastguard Worker            print(a)
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
77*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(exception_as_value)
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
80*da0073e9SAndroid Build Coastguard Worker        def foo_no_decl_always_throws():
81*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Hi")
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker        # function that has no declared type but always throws set to None
84*da0073e9SAndroid Build Coastguard Worker        output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
85*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(output_type) == "NoneType")
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
88*da0073e9SAndroid Build Coastguard Worker        def foo_decl_always_throws():
89*da0073e9SAndroid Build Coastguard Worker            # type: () -> Tensor
90*da0073e9SAndroid Build Coastguard Worker            raise Exception("Hi")  # noqa: TRY002
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker        output_type = next(foo_decl_always_throws.graph.outputs()).type()
93*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(output_type) == "Tensor")
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker        def foo():
96*da0073e9SAndroid Build Coastguard Worker            raise 3 + 4
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
99*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(foo)
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker        # a escapes scope
102*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
103*da0073e9SAndroid Build Coastguard Worker        def foo():
104*da0073e9SAndroid Build Coastguard Worker            if 1 == 1:
105*da0073e9SAndroid Build Coastguard Worker                a = 1
106*da0073e9SAndroid Build Coastguard Worker            else:
107*da0073e9SAndroid Build Coastguard Worker                if 1 == 1:
108*da0073e9SAndroid Build Coastguard Worker                    raise Exception("Hi")  # noqa: TRY002
109*da0073e9SAndroid Build Coastguard Worker                else:
110*da0073e9SAndroid Build Coastguard Worker                    raise Exception("Hi")  # noqa: TRY002
111*da0073e9SAndroid Build Coastguard Worker            return a
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), 1)
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
116*da0073e9SAndroid Build Coastguard Worker        def tuple_fn():
117*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("hello", "goodbye")
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
120*da0073e9SAndroid Build Coastguard Worker            tuple_fn()
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
123*da0073e9SAndroid Build Coastguard Worker        def no_message():
124*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
127*da0073e9SAndroid Build Coastguard Worker            no_message()
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker    def test_assertions(self):
130*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit(
131*da0073e9SAndroid Build Coastguard Worker            """
132*da0073e9SAndroid Build Coastguard Worker            def foo(cond):
133*da0073e9SAndroid Build Coastguard Worker                assert bool(cond), "hi"
134*da0073e9SAndroid Build Coastguard Worker                return 0
135*da0073e9SAndroid Build Coastguard Worker        """
136*da0073e9SAndroid Build Coastguard Worker        )
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker        cu.foo(torch.tensor(1))
139*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
140*da0073e9SAndroid Build Coastguard Worker            cu.foo(torch.tensor(0))
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
143*da0073e9SAndroid Build Coastguard Worker        def foo(cond):
144*da0073e9SAndroid Build Coastguard Worker            assert bool(cond), "hi"
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker        foo(torch.tensor(1))
147*da0073e9SAndroid Build Coastguard Worker        # we don't currently validate the name of the exception
148*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
149*da0073e9SAndroid Build Coastguard Worker            foo(torch.tensor(0))
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker    def test_python_op_exception(self):
152*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
153*da0073e9SAndroid Build Coastguard Worker        def python_op(x):
154*da0073e9SAndroid Build Coastguard Worker            raise Exception("bad!")  # noqa: TRY002
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
157*da0073e9SAndroid Build Coastguard Worker        def fn(x):
158*da0073e9SAndroid Build Coastguard Worker            return python_op(x)
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
161*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "operation failed in the TorchScript interpreter"
162*da0073e9SAndroid Build Coastguard Worker        ):
163*da0073e9SAndroid Build Coastguard Worker            fn(torch.tensor(4))
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker    def test_dict_expansion_raises_error(self):
166*da0073e9SAndroid Build Coastguard Worker        def fn(self):
167*da0073e9SAndroid Build Coastguard Worker            d = {"foo": 1, "bar": 2, "baz": 3}
168*da0073e9SAndroid Build Coastguard Worker            return {**d}
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
171*da0073e9SAndroid Build Coastguard Worker            torch.jit.frontend.NotSupportedError, "Dict expansion "
172*da0073e9SAndroid Build Coastguard Worker        ):
173*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(fn)
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker    def test_custom_python_exception(self):
176*da0073e9SAndroid Build Coastguard Worker        class MyValueError(ValueError):
177*da0073e9SAndroid Build Coastguard Worker            pass
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
180*da0073e9SAndroid Build Coastguard Worker        def fn():
181*da0073e9SAndroid Build Coastguard Worker            raise MyValueError("test custom exception")
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
184*da0073e9SAndroid Build Coastguard Worker            torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"
185*da0073e9SAndroid Build Coastguard Worker        ):
186*da0073e9SAndroid Build Coastguard Worker            fn()
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker    def test_custom_python_exception_defined_elsewhere(self):
189*da0073e9SAndroid Build Coastguard Worker        from jit.myexception import MyKeyError
190*da0073e9SAndroid Build Coastguard Worker
191*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
192*da0073e9SAndroid Build Coastguard Worker        def fn():
193*da0073e9SAndroid Build Coastguard Worker            raise MyKeyError("This is a user defined key error")
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
196*da0073e9SAndroid Build Coastguard Worker            torch.jit.Error,
197*da0073e9SAndroid Build Coastguard Worker            "jit.myexception.MyKeyError: This is a user defined key error",
198*da0073e9SAndroid Build Coastguard Worker        ):
199*da0073e9SAndroid Build Coastguard Worker            fn()
200