1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Workerimport torch 3*da0073e9SAndroid Build Coastguard Workerfrom torch import nn 4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerr""" 8*da0073e9SAndroid Build Coastguard WorkerTest TorchScript exception handling. 9*da0073e9SAndroid Build Coastguard Worker""" 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerclass TestException(TestCase): 13*da0073e9SAndroid Build Coastguard Worker def test_pyop_exception_message(self): 14*da0073e9SAndroid Build Coastguard Worker class Foo(torch.jit.ScriptModule): 15*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 16*da0073e9SAndroid Build Coastguard Worker super().__init__() 17*da0073e9SAndroid Build Coastguard Worker self.conv = nn.Conv2d(1, 10, kernel_size=5) 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 20*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 21*da0073e9SAndroid Build Coastguard Worker return self.conv(x) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker foo = Foo() 24*da0073e9SAndroid Build Coastguard Worker # testing that the correct error message propagates 25*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 26*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d" 27*da0073e9SAndroid Build Coastguard Worker ): 28*da0073e9SAndroid Build Coastguard Worker foo(torch.ones([123])) # wrong size 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker def test_builtin_error_messsage(self): 31*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"): 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 34*da0073e9SAndroid Build Coastguard Worker def close_match(x): 35*da0073e9SAndroid Build Coastguard Worker return x.masked_fill(True) 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 38*da0073e9SAndroid Build Coastguard Worker RuntimeError, 39*da0073e9SAndroid Build Coastguard Worker "This op may not exist or may not be currently " "supported in TorchScript", 40*da0073e9SAndroid Build Coastguard Worker ): 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 43*da0073e9SAndroid Build Coastguard Worker def unknown_op(x): 44*da0073e9SAndroid Build Coastguard Worker torch.set_anomaly_enabled(True) 45*da0073e9SAndroid Build Coastguard Worker return x 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker def test_exceptions(self): 48*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit( 49*da0073e9SAndroid Build Coastguard Worker """ 50*da0073e9SAndroid Build Coastguard Worker def foo(cond): 51*da0073e9SAndroid Build Coastguard Worker if bool(cond): 52*da0073e9SAndroid Build Coastguard Worker raise ValueError(3) 53*da0073e9SAndroid Build Coastguard Worker return 1 54*da0073e9SAndroid Build Coastguard Worker """ 55*da0073e9SAndroid Build Coastguard Worker ) 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker cu.foo(torch.tensor(0)) 58*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.Error, "3"): 59*da0073e9SAndroid Build Coastguard Worker cu.foo(torch.tensor(1)) 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker def foo(cond): 62*da0073e9SAndroid Build Coastguard Worker a = 3 63*da0073e9SAndroid Build Coastguard Worker if bool(cond): 64*da0073e9SAndroid Build Coastguard Worker raise ArbitraryError(a, "hi") # noqa: F821 65*da0073e9SAndroid Build Coastguard Worker if 1 == 2: 66*da0073e9SAndroid Build Coastguard Worker raise ArbitraryError # noqa: F821 67*da0073e9SAndroid Build Coastguard Worker return a 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"): 70*da0073e9SAndroid Build Coastguard Worker torch.jit.script(foo) 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker def exception_as_value(): 73*da0073e9SAndroid Build Coastguard Worker a = Exception() 74*da0073e9SAndroid Build Coastguard Worker print(a) 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"): 77*da0073e9SAndroid Build Coastguard Worker torch.jit.script(exception_as_value) 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 80*da0073e9SAndroid Build Coastguard Worker def foo_no_decl_always_throws(): 81*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Hi") 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker # function that has no declared type but always throws set to None 84*da0073e9SAndroid Build Coastguard Worker output_type = next(foo_no_decl_always_throws.graph.outputs()).type() 85*da0073e9SAndroid Build Coastguard Worker self.assertTrue(str(output_type) == "NoneType") 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 88*da0073e9SAndroid Build Coastguard Worker def foo_decl_always_throws(): 89*da0073e9SAndroid Build Coastguard Worker # type: () -> Tensor 90*da0073e9SAndroid Build Coastguard Worker raise Exception("Hi") # noqa: TRY002 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker output_type = next(foo_decl_always_throws.graph.outputs()).type() 93*da0073e9SAndroid Build Coastguard Worker self.assertTrue(str(output_type) == "Tensor") 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker def foo(): 96*da0073e9SAndroid Build Coastguard Worker raise 3 + 4 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"): 99*da0073e9SAndroid Build Coastguard Worker torch.jit.script(foo) 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker # a escapes scope 102*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 103*da0073e9SAndroid Build Coastguard Worker def foo(): 104*da0073e9SAndroid Build Coastguard Worker if 1 == 1: 105*da0073e9SAndroid Build Coastguard Worker a = 1 106*da0073e9SAndroid Build Coastguard Worker else: 107*da0073e9SAndroid Build Coastguard Worker if 1 == 1: 108*da0073e9SAndroid Build Coastguard Worker raise Exception("Hi") # noqa: TRY002 109*da0073e9SAndroid Build Coastguard Worker else: 110*da0073e9SAndroid Build Coastguard Worker raise Exception("Hi") # noqa: TRY002 111*da0073e9SAndroid Build Coastguard Worker return a 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), 1) 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 116*da0073e9SAndroid Build Coastguard Worker def tuple_fn(): 117*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("hello", "goodbye") 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"): 120*da0073e9SAndroid Build Coastguard Worker tuple_fn() 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 123*da0073e9SAndroid Build Coastguard Worker def no_message(): 124*da0073e9SAndroid Build Coastguard Worker raise RuntimeError 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"): 127*da0073e9SAndroid Build Coastguard Worker no_message() 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker def test_assertions(self): 130*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit( 131*da0073e9SAndroid Build Coastguard Worker """ 132*da0073e9SAndroid Build Coastguard Worker def foo(cond): 133*da0073e9SAndroid Build Coastguard Worker assert bool(cond), "hi" 134*da0073e9SAndroid Build Coastguard Worker return 0 135*da0073e9SAndroid Build Coastguard Worker """ 136*da0073e9SAndroid Build Coastguard Worker ) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker cu.foo(torch.tensor(1)) 139*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): 140*da0073e9SAndroid Build Coastguard Worker cu.foo(torch.tensor(0)) 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 143*da0073e9SAndroid Build Coastguard Worker def foo(cond): 144*da0073e9SAndroid Build Coastguard Worker assert bool(cond), "hi" 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker foo(torch.tensor(1)) 147*da0073e9SAndroid Build Coastguard Worker # we don't currently validate the name of the exception 148*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): 149*da0073e9SAndroid Build Coastguard Worker foo(torch.tensor(0)) 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker def test_python_op_exception(self): 152*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 153*da0073e9SAndroid Build Coastguard Worker def python_op(x): 154*da0073e9SAndroid Build Coastguard Worker raise Exception("bad!") # noqa: TRY002 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 157*da0073e9SAndroid Build Coastguard Worker def fn(x): 158*da0073e9SAndroid Build Coastguard Worker return python_op(x) 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 161*da0073e9SAndroid Build Coastguard Worker RuntimeError, "operation failed in the TorchScript interpreter" 162*da0073e9SAndroid Build Coastguard Worker ): 163*da0073e9SAndroid Build Coastguard Worker fn(torch.tensor(4)) 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker def test_dict_expansion_raises_error(self): 166*da0073e9SAndroid Build Coastguard Worker def fn(self): 167*da0073e9SAndroid Build Coastguard Worker d = {"foo": 1, "bar": 2, "baz": 3} 168*da0073e9SAndroid Build Coastguard Worker return {**d} 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 171*da0073e9SAndroid Build Coastguard Worker torch.jit.frontend.NotSupportedError, "Dict expansion " 172*da0073e9SAndroid Build Coastguard Worker ): 173*da0073e9SAndroid Build Coastguard Worker torch.jit.script(fn) 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker def test_custom_python_exception(self): 176*da0073e9SAndroid Build Coastguard Worker class MyValueError(ValueError): 177*da0073e9SAndroid Build Coastguard Worker pass 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 180*da0073e9SAndroid Build Coastguard Worker def fn(): 181*da0073e9SAndroid Build Coastguard Worker raise MyValueError("test custom exception") 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 184*da0073e9SAndroid Build Coastguard Worker torch.jit.Error, "jit.test_exception.MyValueError: test custom exception" 185*da0073e9SAndroid Build Coastguard Worker ): 186*da0073e9SAndroid Build Coastguard Worker fn() 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker def test_custom_python_exception_defined_elsewhere(self): 189*da0073e9SAndroid Build Coastguard Worker from jit.myexception import MyKeyError 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 192*da0073e9SAndroid Build Coastguard Worker def fn(): 193*da0073e9SAndroid Build Coastguard Worker raise MyKeyError("This is a user defined key error") 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 196*da0073e9SAndroid Build Coastguard Worker torch.jit.Error, 197*da0073e9SAndroid Build Coastguard Worker "jit.myexception.MyKeyError: This is a user defined key error", 198*da0073e9SAndroid Build Coastguard Worker ): 199*da0073e9SAndroid Build Coastguard Worker fn() 200