1# Owner(s): ["module: dynamo"] 2 3import torch 4import torch._dynamo.config 5import torch._dynamo.test_case 6import torch._functorch.config 7import torch.nn 8import torch.utils.checkpoint 9 10 11class ExceptionTests(torch._dynamo.test_case.TestCase): 12 def test_exception(self): 13 def fn(x): 14 x = torch.cos(x) 15 try: 16 x = torch.sin(x) 17 raise NotImplementedError 18 except Exception: 19 x = torch.sigmoid(x) 20 21 return x 22 23 x = torch.randn(4) 24 ref = fn(x) 25 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 26 res = opt_fn(x) 27 self.assertEqual(ref, res) 28 29 def test_exception2(self): 30 def fn(x): 31 x = torch.cos(x) 32 try: 33 x = torch.sin(x) 34 raise NotImplementedError 35 except (NotImplementedError, AttributeError) as e: 36 x = torch.sigmoid(x) 37 38 return x 39 40 x = torch.randn(4) 41 ref = fn(x) 42 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 43 res = opt_fn(x) 44 self.assertEqual(ref, res) 45 46 def test_exception3(self): 47 def fn(x): 48 x = torch.cos(x) 49 try: 50 x = torch.sin(x) 51 raise NotImplementedError("Not implemented") 52 except AssertionError: 53 x = torch.sigmoid(x) 54 except NotImplementedError: 55 x = torch.cos(x) 56 finally: 57 x = torch.cos(x) 58 59 return x 60 61 x = torch.randn(4) 62 ref = fn(x) 63 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 64 res = opt_fn(x) 65 self.assertEqual(ref, res) 66 67 def test_exception4(self): 68 def fn(x): 69 for i in range(10): 70 if i == 5: 71 return x 72 try: 73 x = torch.sin(x) 74 raise NotImplementedError 75 except Exception: 76 x = torch.sigmoid(x) 77 78 return x 79 80 x = torch.randn(4) 81 ref = fn(x) 82 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 83 res = opt_fn(x) 84 self.assertEqual(ref, res) 85 86 def test_exception_with_another_exception(self): 87 def fn(x): 88 x = torch.cos(x) 89 try: 90 x = torch.sin(x) 91 raise NotImplementedError("Not implemented") 92 except NotImplementedError as e: 93 x = torch.sigmoid(x) 94 try: 95 x = torch.cos(x) 96 raise AssertionError 97 except AssertionError: 98 x = torch.cos(x) 99 100 x = torch.randn(4) 101 ref = fn(x) 102 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 103 res = opt_fn(x) 104 self.assertEqual(ref, res) 105 106 def test_exception_else(self): 107 def gn(x): 108 return torch.cos(x) 109 110 def fn(x): 111 x = torch.cos(x) 112 try: 113 x = torch.sin(x) 114 x = gn(x) 115 except Exception: 116 x = torch.sigmoid(x) 117 else: 118 x = torch.cos(x) 119 120 return x 121 122 x = torch.randn(4) 123 ref = fn(x) 124 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 125 res = opt_fn(x) 126 self.assertEqual(ref, res) 127 128 # TODO(anijain2305) - does not work with fullgraph=True 129 def test_exception_with_another_exception2(self): 130 def gn(x): 131 try: 132 x = torch.cos(x) 133 raise NotImplementedError("Not implemented") 134 except NotImplementedError as e: 135 x = torch.sigmoid(x) 136 raise 137 138 def fn(x): 139 try: 140 x = torch.cos(x) 141 gn(x) 142 except Exception: 143 pass 144 return x 145 146 x = torch.randn(4) 147 ref = fn(x) 148 # Cant use fullgraph=True because RERAISE is not supported 149 opt_fn = torch.compile(fn, backend="eager") 150 res = opt_fn(x) 151 152 # TODO(anijain2305) - does not work with fullgraph=True 153 def test_exception_with_ctx_manager(self): 154 def fn(x): 155 x = torch.cos(x) 156 try: 157 with torch.no_grad(): 158 x = torch.sin(x) 159 raise NotImplementedError("Not implemented") 160 except NotImplementedError as e: 161 x = torch.sigmoid(x) 162 return x 163 164 x = torch.randn(4) 165 ref = fn(x) 166 # Cant use fullgraph=True because WITH_EXCEPT_START is not supported 167 opt_fn = torch.compile(fn, backend="eager") 168 res = opt_fn(x) 169 self.assertEqual(ref, res) 170 171 def test_exception_raised_from_child(self): 172 def gn(): 173 raise NotImplementedError("foo") 174 175 def fn(x): 176 x = torch.cos(x) 177 try: 178 x = torch.sin(x) 179 gn() 180 x = torch.sin(x) 181 except Exception: 182 x = torch.sigmoid(x) 183 184 return x 185 186 x = torch.randn(4) 187 ref = fn(x) 188 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 189 res = opt_fn(x) 190 self.assertEqual(ref, res) 191 192 def test_dynamo_undo_kw_names(self): 193 def g(x, k=None): 194 if k: 195 raise TypeError("error") 196 return x.sin() 197 198 def fn(x): 199 d = {"a": x} 200 try: 201 g(x, k=True) 202 except Exception: 203 y = 0 204 for _, b in d.items(): # noqa: PERF102 205 y += b.sum() 206 return y 207 208 x = torch.randn(2, 3) 209 expected = fn(x) 210 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 211 got = opt_fn(x) 212 self.assertEqual(expected, got) 213 214 def test_nn_module_getattr(self): 215 class A: 216 def __init__(self) -> None: 217 self._b = 20 218 219 def __getattr__(self, name): 220 fixed_name = "_" + name 221 if fixed_name in self.__dict__: 222 return self.__dict__[fixed_name] 223 raise AttributeError(f"{name} absent") 224 225 class B(A): 226 def __init__(self) -> None: 227 self.a = 10 228 229 def __getattr__(self, name): 230 try: 231 return super().__getattr__(name) 232 except AttributeError: 233 return 30 234 235 obj = B() 236 237 def fn(x): 238 return x * obj.a * obj.b * obj.c 239 240 x = torch.ones(4) 241 ref = fn(x) 242 print(ref) 243 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 244 res = opt_fn(x) 245 self.assertEqual(ref, res) 246 247 @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) 248 def test_custom_getattr_on_module_exception(self): 249 class Foo(torch.nn.Module): 250 def __init__(self, a=3): 251 super().__init__() 252 self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2)) 253 254 def __getattr__(self, name): 255 try: 256 return super().__getattr__(name) # defer to nn.Module's logic 257 except AttributeError: 258 if name == "a_copy": 259 return self.a 260 raise 261 262 def forward(self, x): 263 return x * self.a * self.a_copy 264 265 mod = Foo() 266 opt_mod = torch.compile(mod, backend="eager", fullgraph=True) 267 268 x = torch.ones(4) 269 self.assertEqual(mod(x), opt_mod(x)) 270 271 def test_attribute_error_from_getattr(self): 272 class Mock: 273 def __init__(self): 274 self.a = 5 275 276 def __getattr__(self, name): 277 if name != "a": 278 raise AttributeError("missing") 279 return self.__dict__["a"] 280 281 mock = Mock() 282 283 def fn(x): 284 if hasattr(mock, "b"): 285 return torch.cos(x) 286 return torch.sin(x) 287 288 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 289 x = torch.randn(4) 290 ref = fn(x) 291 res = opt_fn(x) 292 self.assertEqual(ref, res) 293 294 def test_stop_iteration(self): 295 def zip_longest(*iterables, fillvalue=None): 296 # Get the iterators for each iterable 297 iterators = [iter(it) for it in iterables] 298 299 result = [] 300 while True: 301 for it in iterators: 302 try: 303 value = next(it) 304 except StopIteration: 305 result.append(fillvalue) 306 return result 307 result.append(value) 308 309 def fn(x, y): 310 torch.cos(torch.randn(4)) 311 return tuple(zip_longest(x, y)) 312 313 x = [1, 2, 3, 4] 314 y = [10, 11, 12] 315 316 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 317 ref = fn(x, y) 318 res = opt_fn(x, y) 319 self.assertEqual(ref, res) 320 321 def test_nn_reraise(self): 322 class M(torch.nn.Module): 323 def forward(self, x): 324 raise ValueError("woof") 325 return x + 2 326 327 m = M() 328 m.register_forward_pre_hook(lambda m, go: None) 329 330 torch._dynamo.utils.clear_compilation_metrics() 331 opt_call = torch.compile(lambda x: m(x), backend="eager") 332 self.assertRaises(ValueError, lambda: opt_call(torch.randn(3))) 333 metrics = torch._dynamo.utils.get_compilation_metrics() 334 self.assertEqual(metrics[0].fail_reason, "Observed exception") 335 336 def test_key_error(self): 337 def fn(x, d): 338 try: 339 a = d["b"] 340 except KeyError: 341 a = 2 342 return x * a 343 344 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 345 x = torch.randn(4) 346 d = {"a": 1} 347 ref = fn(x, d) 348 res = opt_fn(x, d) 349 self.assertEqual(ref, res) 350 351 def test_atrribute_error(self): 352 class Mock: 353 def __init__(self): 354 self.a = 1 355 356 mock = Mock() 357 358 def fn(x): 359 try: 360 c = 2 361 mock.b 362 except AttributeError: 363 c = 3 364 return torch.sin(x) * c 365 366 opt_fn = torch.compile(fn, backend="eager") 367 x = torch.randn(4) 368 ref = fn(x) 369 res = opt_fn(x) 370 self.assertEqual(ref, res) 371 372 def test_raise_from_None(self): 373 # Inspired from os.environ 374 class MyMapping: 375 def __init__(self, d): 376 self._d = d 377 378 def __getitem__(self, key): 379 try: 380 value = self._d[key] 381 except KeyError: 382 raise KeyError(key) from None 383 return value 384 385 d = MyMapping({"a": 10, "b": 20}) 386 387 def mapping_get(obj, key, value=None): 388 try: 389 return obj.__getitem__(key) 390 except KeyError: 391 return value 392 393 def fn(x, d, key): 394 x = torch.sin(x + 1) 395 return x, mapping_get(d, key) 396 397 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 398 399 x = torch.rand(2, 3) 400 ref = fn(x, d, "m") 401 res = opt_fn(x, d, "m") 402 self.assertEqual(ref[0], res[0]) 403 self.assertEqual(ref[1], res[1]) 404 405 406if __name__ == "__main__": 407 from torch._dynamo.test_case import run_tests 408 409 run_tests() 410