1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5from typing import Any, List 6 7import torch 8from torch.testing._internal.common_utils import skipIfTorchDynamo 9from torch.testing._internal.jit_utils import JitTestCase, make_global 10 11 12# Make the helper files in test/ importable 13pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 14sys.path.append(pytorch_test_dir) 15 16if __name__ == "__main__": 17 raise RuntimeError( 18 "This test file is not meant to be run directly, use:\n\n" 19 "\tpython test/test_jit.py TESTNAME\n\n" 20 "instead." 21 ) 22 23 24class TestWith(JitTestCase): 25 """ 26 A suite of tests for with statements. 27 """ 28 29 def test_with_as(self): 30 """ 31 Check that with statements that use the 'as' keyword to bind expressions 32 to targets work as expected. 33 """ 34 35 @torch.jit.script 36 class Context: 37 """ 38 This class implements a basic context manager interface for use in 39 the unit tests. Unlike Context, the stateful part of this class 40 is a Tensor that is mutated in-place so that modifications made in the 41 JIT interpreter are visible outside of it. 42 """ 43 44 def __init__(self, start: int): 45 self.count = torch.tensor([start], dtype=torch.double) 46 47 def __enter__(self): 48 self.count.add_(0.3) 49 return self.count 50 51 def __exit__(self, type: Any, value: Any, tb: Any) -> bool: 52 self.count.sub_(0.3) 53 return True 54 55 make_global(Context) 56 57 def test_basic(x: torch.Tensor) -> torch.Tensor: 58 """Basic test with one with-statement.""" 59 60 c = Context(1) 61 62 with c as mult: 63 y = x + mult 64 65 y *= c.count 66 return y 67 68 def test_pass(x: torch.Tensor) -> torch.Tensor: 69 """ 70 Test with a pass statement inside a with-statement. Although 71 the body of the with is empty, __enter__ and __exit__ should 72 still be called. 73 """ 74 c = Context(1) 75 76 with c as mult: 77 pass 78 79 x *= c.count 80 return x 81 82 def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: 83 """ 84 Test that returning early from inside a with-statement works 85 as expected. 86 """ 87 with c as mult: 88 y = x + mult 89 return y 90 91 x = y + y 92 return x 93 94 def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: 95 """ 96 Test that conditionally returning early from inside a with-statement works 97 as expected. 98 """ 99 with c as mult: 100 y = x + mult 101 if mult > 0: 102 return y 103 104 x = y + y 105 return x 106 107 def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: 108 """ 109 Test that breaking early from inside a with-statement works 110 as expected. 111 """ 112 with c as mult: 113 for a in l: 114 if a == 0: 115 break 116 x += a * mult 117 118 return x 119 120 def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: 121 """ 122 Test that using continue inside a with-statement works 123 as expected. 124 """ 125 with c as mult: 126 for a in l: 127 if a == 0: 128 continue 129 x += a * mult 130 131 return x 132 133 def test_serial(x: torch.Tensor) -> torch.Tensor: 134 """ 135 Test two with-statements in a row. 136 """ 137 c = Context(1) 138 139 with c as mult: 140 y = x + mult 141 142 with c as mult: 143 y *= mult 144 145 return y 146 147 def test_nested(x: torch.Tensor) -> torch.Tensor: 148 """ 149 Test nested with-statements. 150 """ 151 c = Context(1) 152 153 with c as m: 154 with c as n: 155 y = x + n 156 157 y *= m 158 159 return y 160 161 def test_combined(x: torch.Tensor) -> torch.Tensor: 162 """ 163 Test a with-statement with multiple with items. 164 """ 165 c = Context(1) 166 d = Context(2) 167 168 with c as m, d as n: 169 y = x + (m + n) 170 171 return y 172 173 test_input = torch.randn(2, 2) 174 test_context = Context(2) 175 test_list = [2, 0, 1, 3, 0, 2] 176 177 self.checkScript(test_basic, (test_input,)) 178 self.checkScript(test_pass, (test_input,)) 179 self.checkScript(test_early_return, (test_input, test_context)) 180 self.checkScript(test_break, (test_input, test_context, test_list)) 181 self.checkScript(test_continue, (test_input, test_context, test_list)) 182 self.assertEqual(test_context.count, 2) 183 self.checkScript(test_serial, (test_input,)) 184 self.checkScript(test_nested, (test_input,)) 185 self.checkScript(test_combined, (test_input,)) 186 187 def test_with_no_as(self): 188 """ 189 Check that with statements that do not use the 'as' keyword to bind expressions 190 to targets work as expected. 191 """ 192 193 @torch.jit.script 194 class Context: 195 """ 196 This class implements a basic context manager interface for use in 197 the unit tests. Unlike Context, the stateful part of this class 198 is a Tensor that is mutated in-place so that modifications made in the 199 JIT interpreter are visible outside of it. 200 """ 201 202 def __init__(self, start: int): 203 self.count = torch.tensor([start], dtype=torch.double) 204 205 def __enter__(self): 206 self.count.add_(0.3) 207 return self.count 208 209 def __exit__(self, type: Any, value: Any, tb: Any): 210 self.count.sub_(0.3) 211 212 make_global(Context) 213 214 def test_basic(x: torch.Tensor) -> torch.Tensor: 215 """Basic test with one with-statement.""" 216 217 c = Context(1) 218 219 with c: 220 y = x + c.count 221 222 y *= c.count 223 return y 224 225 def test_pass(x: torch.Tensor) -> torch.Tensor: 226 """ 227 Test with a pass statement inside a with-statement. Although 228 the body of the with is empty, __enter__ and __exit__ should 229 still be called. 230 """ 231 c = Context(1) 232 233 with c: 234 pass 235 236 x *= c.count 237 return x 238 239 def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: 240 """ 241 Test that returning early from inside a with-statement works 242 as expected. 243 """ 244 with c: 245 y = x + c.count 246 return y 247 248 x = y + y 249 return x 250 251 def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: 252 """ 253 Test that conditionally returning early from inside a with-statement works 254 as expected. 255 """ 256 with c: 257 y = x + c.count 258 if c.count > 0: 259 return y 260 261 x = y + y 262 return x 263 264 def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: 265 """ 266 Test that breaking early from inside a with-statement works 267 as expected. 268 """ 269 with c: 270 for a in l: 271 if a == 0: 272 break 273 x += a * c.count 274 275 return x 276 277 def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: 278 """ 279 Test that using continue inside a with-statement works 280 as expected. 281 """ 282 with c: 283 for a in l: 284 if a == 0: 285 continue 286 x += a * c.count 287 288 return x 289 290 def test_serial(x: torch.Tensor) -> torch.Tensor: 291 """ 292 Test two with-statements in a row. 293 """ 294 c = Context(1) 295 296 with c: 297 y = x + c.count 298 299 with c: 300 y *= c.count 301 302 return y 303 304 def test_nested(x: torch.Tensor) -> torch.Tensor: 305 """ 306 Test nested with-statements. 307 """ 308 c = Context(1) 309 310 with c: 311 with c: 312 y = x + c.count 313 314 y *= c.count 315 316 return y 317 318 def test_combined(x: torch.Tensor) -> torch.Tensor: 319 """ 320 Test a with-statement with multiple with items. 321 """ 322 c = Context(1) 323 d = Context(2) 324 325 with c, d: 326 y = x + (c.count + d.count) 327 328 return y 329 330 test_input = torch.randn(2, 2) 331 test_context = Context(2) 332 test_list = [2, 0, 1, 3, 0, 2] 333 334 self.checkScript(test_basic, (test_input,)) 335 self.checkScript(test_pass, (test_input,)) 336 self.checkScript(test_early_return, (test_input, test_context)) 337 self.checkScript(test_break, (test_input, test_context, test_list)) 338 self.checkScript(test_continue, (test_input, test_context, test_list)) 339 self.assertEqual(test_context.count, 2) 340 self.checkScript(test_serial, (test_input,)) 341 self.checkScript(test_nested, (test_input,)) 342 self.checkScript(test_combined, (test_input,)) 343 344 def test_with_exceptions(self): 345 """ 346 Check that exceptions thrown in the bodies of with-statements are 347 handled correctly. 348 """ 349 350 @torch.jit.script 351 class Context: 352 """ 353 This class implements a basic context manager interface for use in 354 the unit tests. Unlike Context, the stateful part of this class 355 is a Tensor that is mutated in-place so that modifications made in the 356 JIT interpreter are visible outside of it. 357 """ 358 359 def __init__(self, start: int): 360 self.count = torch.tensor([start], dtype=torch.double) 361 362 def __enter__(self): 363 self.count.add_(0.3) 364 return self.count 365 366 def __exit__(self, type: Any, value: Any, tb: Any): 367 self.count.sub_(0.3) 368 369 make_global(Context) 370 371 @torch.jit.script 372 def method_that_raises() -> torch.Tensor: 373 raise Exception("raised exception") # noqa: TRY002 374 375 @torch.jit.script 376 def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor: 377 """ 378 Test the case in which an exception is thrown while executing the body of a with-statement. 379 """ 380 with c as _: 381 x += method_that_raises() 382 383 return x 384 385 @torch.jit.script 386 def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor: 387 """ 388 Test the case in which an exception is thrown while executing the body of a nested with-statement. 389 """ 390 with c as _: 391 with c as _: 392 x += method_that_raises() 393 394 return x 395 396 @torch.jit.script 397 def with_that_raises(c: Context) -> torch.Tensor: 398 a = torch.tensor([1]) 399 400 with c as _: 401 a += method_that_raises() 402 403 return a 404 405 @torch.jit.script 406 def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor: 407 """ 408 Test the case in which an exception is thrown while there are active with-statements in two different 409 frames. 410 """ 411 with c as _: 412 x += with_that_raises(c) 413 414 return x 415 416 c = Context(1) 417 418 # checkScript and checkScriptRaisesRegex cannot be used because the string frontend will 419 # not compile class types (of which Context, the context manager being used for this test 420 # is one). 421 with self.assertRaisesRegexWithHighlight( 422 Exception, r"raised exception", 'raise Exception("raised exception' 423 ): 424 test_exception(torch.randn(2), c) 425 self.assertEqual(c.count, 1) 426 427 with self.assertRaisesRegexWithHighlight( 428 Exception, r"raised exception", 'raise Exception("raised exception' 429 ): 430 test_exception_nested(torch.randn(2), c) 431 self.assertEqual(c.count, 1) 432 433 with self.assertRaisesRegexWithHighlight( 434 Exception, r"raised exception", 'raise Exception("raised exception' 435 ): 436 test_exception_fn_call(torch.randn(2), c) 437 self.assertEqual(c.count, 1) 438 439 def test_with_errors(self): 440 """ 441 Check that errors related to with-statements are detected and reported correctly. 442 """ 443 444 @torch.jit.script 445 class NoEnterNoExit: 446 """ 447 This class is missing __enter__ and __exit__ methods. 448 """ 449 450 def __init__(self) -> None: 451 self.count = 1 452 453 @torch.jit.script 454 class BadEnter: 455 """ 456 This class has an __enter__ method with an incorrect signature. 457 """ 458 459 def __init__(self) -> None: 460 self.count = 1 461 462 def __enter__(self, incr: int): # noqa: PLE0302 463 self.count += incr 464 465 def __exit__(self, type: Any, value: Any, tb: Any): 466 pass 467 468 @torch.jit.script 469 class BadExit: 470 """ 471 This class has an __exit__ method with an incorrect signature. 472 """ 473 474 def __init__(self) -> None: 475 self.count = 1 476 477 def __enter__(self): 478 self.count += 1 479 480 def __exit__(self, type: Any, value: Any): # noqa: PLE0302 481 pass 482 483 @torch.jit.script 484 class ExitIncorrectTypes: 485 """ 486 This class has an __exit__ method with unsupported argument types. 487 """ 488 489 def __init__(self) -> None: 490 self.count = 1 491 492 def __enter__(self): 493 self.count += 1 494 495 def __exit__(self, type: Any, value: int, tb: int): 496 pass 497 498 def test_no_enter_no_exit(x: torch.Tensor, cm: NoEnterNoExit) -> torch.Tensor: 499 with cm as _: 500 pass 501 502 return x 503 504 def test_bad_enter(x: torch.Tensor, cm: BadEnter) -> torch.Tensor: 505 with cm as _: 506 pass 507 508 return x 509 510 def test_bad_exit(x: torch.Tensor, cm: BadExit) -> torch.Tensor: 511 with cm as _: 512 pass 513 514 return x 515 516 def test_exit_incorrect_types( 517 x: torch.Tensor, cm: ExitIncorrectTypes 518 ) -> torch.Tensor: 519 with cm as _: 520 pass 521 522 return x 523 524 def test_enter_without_object(): 525 with "not_object" as obj: 526 pass 527 528 test_tensor = torch.randn(5, dtype=torch.double) 529 530 with self.assertRaisesRegexWithHighlight( 531 RuntimeError, r"does not define __enter__ and __exit__ methods", "cm" 532 ): 533 self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit())) 534 535 with self.assertRaisesRegexWithHighlight( 536 RuntimeError, 537 r"__enter__ must have only one argument and one return value", 538 "cm", 539 ): 540 self.checkScript(test_bad_enter, (test_tensor, BadEnter())) 541 542 with self.assertRaisesRegexWithHighlight( 543 RuntimeError, r"__exit__ must have four arguments", "cm" 544 ): 545 self.checkScript(test_bad_exit, (test_tensor, BadExit())) 546 547 with self.assertRaisesRegexWithHighlight( 548 RuntimeError, r"argument 2 of __exit__ must have Any type", "cm" 549 ): 550 self.checkScript( 551 test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes()) 552 ) 553 554 with self.assertRaisesRegexWithHighlight( 555 RuntimeError, r"must return an object", '"not_object"' 556 ): 557 self.checkScript(test_enter_without_object, ()) 558 559 def test_with_no_grad(self): 560 """ 561 Check that torch.no_grad() works. Most of these are adapted from 562 corresponding tests for eager-mode no_grad. 563 """ 564 565 # Basic no_grad test. 566 def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 567 with torch.no_grad(): 568 w = x + y 569 570 return w 571 572 s = torch.jit.script(test_no_grad) 573 x = torch.ones(5, 5, requires_grad=True) 574 y = torch.ones(5, 5) * 4 575 w = s(x, y) 576 577 self.assertFalse(w.requires_grad) 578 self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5))) 579 self.assertIsNone(w.grad_fn) 580 581 # Test assignment of a grad-less Tensor to a Tensor with gradients 582 # in a no_grad block. 583 def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 584 with torch.no_grad(): 585 x[0] = y 586 587 return x 588 589 s = torch.jit.script(test_no_grad_assignment) 590 z = torch.randn(5) 591 w = s(x, z) 592 self.assertTrue(w.requires_grad) 593 self.assertIsNone(w.grad_fn) 594 595 # Check that @torch.jit.ignored functions respect no_grad when it is 596 # called in JIT mode. 597 class NoGradModule(torch.nn.Module): 598 @torch.jit.ignore 599 def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 600 w = x + y 601 return w 602 603 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 604 with torch.no_grad(): 605 w = self.adder(x, y) 606 607 return w 608 609 s = torch.jit.script(NoGradModule()) 610 w = s(x, y) 611 612 self.assertFalse(w.requires_grad) 613 614 @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls") 615 def test_with_record_function(self): 616 """ 617 Check that torch.autograd.profiler.record_function context manager is 618 torchscriptable. 619 """ 620 621 def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 622 with torch.autograd.profiler.record_function("foo"): 623 # Nested record_function. 624 with torch.autograd.profiler.record_function("nested"): 625 a = x + y 626 return a 627 628 scripted = torch.jit.script(with_rf) 629 x, y = torch.ones(2), torch.ones(2) 630 with torch.autograd.profiler.profile() as p: 631 scripted(x, y) 632 633 # Need to call below to populate CPU children. 634 p.key_averages() 635 function_events = p.function_events 636 # Event with name "foo" should be recorded. 637 rf_events = [evt for evt in function_events if evt.name == "foo"] 638 self.assertEqual(len(rf_events), 1) 639 rf_event = rf_events[0] 640 child_events = rf_event.cpu_children 641 # Ensure we find nested record_function event 642 self.assertTrue("nested" in (child.name for child in child_events)) 643 nested_function_event = [ 644 evt for evt in function_events if evt.name == "nested" 645 ][0] 646 # Nested record function should have child "aten::add" 647 nested_child_events = nested_function_event.cpu_children 648 self.assertTrue("aten::add" in (child.name for child in nested_child_events)) 649