# Owner(s): ["oncall: jit"]

import torch
from torch.cuda.amp import autocast
from typing import Optional, Tuple

import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing import FileCheck
from jit.test_models import MnistNet

TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()

@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestAutocast(JitTestCase):
    def setUp(self):
        # common input tensors
        if TEST_CUDA:
            self.a_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.b_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.c_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.d_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.a_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.b_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.c_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.d_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
        self.old_value = torch._C._jit_set_autocast_mode(True)
        super().setUp()

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.old_value)
        super().tearDown()

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_generic_autocast(self):
        @torch.jit.script
        def fn_cuda_autocast(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y

        @torch.jit.script
        def fn_generic_autocast(a, b):
            with torch.amp.autocast(device_type='cuda'):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32), fn_generic_autocast(self.a_fp32, self.b_fp32))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
    def test_linear_bf16(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(dtype=torch.bfloat16):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.bfloat16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_cpu(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)
        result = fn(self.a_fp32.to('cpu'), self.b_fp32.to('cpu'))
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_off(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state(self):
        @torch.jit.script
        def fn(a, b, use_amp: bool):
            with autocast(enabled=use_amp):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32, True)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state_expr(self):
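        # here the enabled argument is a data-dependent expression rather than a
        # bool parameter, which should likewise be rejected at runtime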
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True if a[0][0] > 0.5 else False):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_explicit_casts(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a.double(), b.double()).float()
                f = torch.mm(c, d).double()
                g = torch.mm(c.double(), f)
                return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float64)
        self.assertEqual(g.dtype, torch.float64)

    # multiple uses of the same input value
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_duplicate_inputs(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                e = torch.mm(a, a)
                f = torch.mm(e, e)
                return e, f
        e, f = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        result = fn(self.a_fp16)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy_with_fp64(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        # fp32 policy should not narrow fp64 to fp32!
        result = fn(self.a_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a, b)
                f = torch.addcmul(e, c, d, value=0.1)
                return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy_fp64(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return torch.addcmul(a, a, b, value=0.1)
        result = fn(self.a_fp32.double(), self.b_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
                return x, y, z, w
        x, y, z, w = fn(self.a_fp16, self.b_fp16, self.c_fp16, self.d_fp16, None)
        self.assertEqual(x.dtype, torch.float32)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy_fp64(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
                return x, y, z, w
        x, y, z, w = fn(self.a_fp32.double(), self.b_fp32.double(), self.c_fp32.double(), self.d_fp32.double(), None)
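        # the set_opt_dtype fp32 policy should leave fp64 inputs untouched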
        self.assertEqual(x.dtype, torch.float64)
        self.assertEqual(y.dtype, torch.float64)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float64)

    @unittest.skipIf(True, "broken due to lack of type propagation")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_control_flow(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    x = 1
                else:
                    e = torch.mm(c, d)
                    x = 2
                f = torch.mm(d, e) * x
                return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    # this works fine in regular Python, but it creates a delicate
    # situation in TorchScript where the types are not consistent across
    # the then/else branches
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_types(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    f = torch.mm(a, b).float()
                else:
                    e = torch.mm(c, d).float()
                    f = torch.mm(a, b)
                return torch.mm(e.float(), f.float())
        result = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # another, more complex case of divergent types
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            if a[0][0] > 0.5:
                with autocast_on:
                    e = torch.mm(a, b)
            else:
                with autocast_off:
                    e = torch.mm(c, d)
            return torch.mm(e, e)
        fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_conditional_autocast(self):
        @torch.jit.script
        def fn(a, b):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            with autocast_on if a[0][0] > 0.5 else autocast_off:
                return torch.mm(a, b)
        # conditional autocast expressions are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_nested_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=False):
                e = torch.mm(a, b)
                with autocast(enabled=True):
                    f = torch.mm(e, c)
                    with autocast(enabled=False):
                        g = torch.mm(e, d)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_implicitly_nested_autocast(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False), autocast(enabled=True):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_reused_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_instance = autocast(enabled=True)
            with autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    # TODO: fix and enable this test?
    # (we could technically fix this, but is it really worth it?)
    @unittest.skipIf(True, "unsupported autocast syntax")
    def test_reused_autocast_expr(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=True) as autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees(self):
        def helper(a, b):
            return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                tmp = helper(a, b)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                return helper(tmp, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_on(self):
        def helper(a, b):
            with autocast(enabled=True):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_off(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripting inside eager autocast
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_eager_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        for i in range(8):
            use_autocast = (i % 2 == 0)
            expected_dtype = torch.float16 if use_autocast else torch.float32
            with autocast(enabled=use_autocast):
                result = fn(self.a_fp32, self.b_fp32)
            self.assertEqual(result.dtype, expected_dtype)

    # traced inside scripting
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing(self):
        def helper(a, b):
            return torch.mm(a, b)

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # traced with autocast inside scripting
    @unittest.skipIf(True, "autocast(False) is ignored inside traced functions")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing_with_autocast(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b) * 2.0

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripted called from traced
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_and_script(self):
        @torch.jit.script
        def fn(a, b):
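            # autocast is enabled inside the scripted callee, so the traced
            # wrapper below is expected to observe fp16 outputs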
            with autocast():
                return torch.mm(a, b)

        def traced(a, b):
            return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # scripted called from traced with autocast
    @unittest.skipIf(True, "scripted called from traced TorchScript is not yet working")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_with_autocast_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)

        def traced(a, b):
            with autocast(enabled=True):
                return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_module(self):
        class TestModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand((N, M), dtype=torch.float32))
                self.linear = torch.nn.Linear(N, M).float()

            def forward(self, input):
                with autocast(enabled=True):
                    output = self.weight.mv(input)
                    output = self.linear(output)
                    return output

        scripted_module = torch.jit.script(TestModule(2, 3)).cuda()
        input = torch.rand(3, dtype=torch.float32, device='cuda')
        result = scripted_module(input)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(True, "autocast decorators not supported")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator(self):
        @torch.jit.script
        @autocast(enabled=True)
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # this is equivalent to running scripted functions inside autocast
    # (see also test_eager_and_script)
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator_outside_jit(self):
        @autocast(enabled=True)
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_inplace(self):
        @torch.jit.script
        def fn(a, b, c):
            with autocast(enabled=True):
                x = torch.addmm(a, b, c)
                y = torch.addmm(a, b, c, out=a)
                z = a.addmm_(b, c)
                return x, y, z
        x, y, z = fn(self.a_fp32, self.b_fp32, self.c_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float32)

    def _test_autocast(self, func, cast_op, *args):
        jit_func = torch.jit.script(func)
        o = func(*args)
        jit_o = jit_func(*args)
        if cast_op is not None:
            FileCheck().check(cast_op).run(jit_func.graph_for(*args))
        for o0, o1 in zip(o, jit_o):
            self.assertEqual(o0.dtype, o1.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api(self):

        def t_autocast_cpu(x, y):
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            with torch.autocast("cuda", dtype=torch.half):
                return torch.mm(x, y)

        def t_cuda_amp_autocast(x, y):
            with torch.cuda.amp.autocast():
                return torch.mm(x, y)

        def t_cpu_amp_autocast(x, y):
            with torch.cpu.amp.autocast():
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cuda_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cpu_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(True, "we need to provide dtype argument at this moment")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api_not_supported(self):

        def t_autocast_cpu(x, y):
            # omitting the dtype argument is not currently supported
            with torch.autocast("cpu"):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            # omitting the dtype argument is not currently supported
            with torch.autocast("cuda"):
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_mixed_dtypes(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            with torch.autocast("cpu", torch.bfloat16):
                with torch.autocast("cuda", torch.float16):
                    cpu_o = torch.mm(cpu0, cpu1)
                    cuda_o = torch.mm(cuda0, cuda1)
                    return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_executor_under_autocast(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            cpu_o = torch.mm(cpu0, cpu1)
            cuda_o = torch.mm(cuda0, cuda1)
            return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)

        with torch.autocast("cpu", torch.bfloat16):
            with torch.autocast("cuda", torch.float16):
                self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cpu", torch.bfloat16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cuda", torch.float16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        # no cast op should be observed when executing outside autocast context
        self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_autodiff(self):
        def t(t0, t1):
            o = torch.mm(t0, t1)
            return o.relu()

        jit_t = torch.jit.script(t)
        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

        # run optimization
        for i in range(5):
            with torch.autocast("cuda", torch.float16):
                jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

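        # reset the gradients and compare the optimized scripted function
        # against eager mode under the same autocast context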
        t0.grad = None
        t1.grad = None
        ref_t0 = t0.detach().requires_grad_()
        ref_t1 = t1.detach().requires_grad_()

        with torch.autocast("cuda", torch.float16):
            o = t(ref_t0, ref_t1)
            jit_o = jit_t(t0, t1)
        jit_o.sum().backward()
        o.sum().backward()
        self.assertEqual(o, jit_o)
        self.assertEqual(t0.grad, ref_t0.grad)
        self.assertEqual(t1.grad, ref_t1.grad)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_call_method_under_autocast(self):
        @torch.jit.interface
        class Iface(torch.nn.Module):
            def forward(self, x, y) -> torch.Tensor:
                pass

        class Impl(Iface):
            def forward(self, x, y):
                return torch.mm(x, y)

        class Thing1(torch.nn.Module):
            impl: Iface

            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    a = torch.mm(x, y)
                    b = self.impl.forward(a, x)
                    return b

        scripted_impl = torch.jit.script(Impl())
        thing1 = Thing1()
        thing1.impl = scripted_impl
        scripted_thing1 = torch.jit.script(thing1)
        x = torch.rand([2, 2])
        y = torch.rand([2, 2])

        # make sure this doesn't throw an error
        with torch.cuda.amp.autocast():
            ans = scripted_thing1.forward(x, y)
        self.assertEqual(torch.mm(torch.mm(x, y), x), ans)

        # sanity check: this isn't supported currently when global autocasting
        # isn't enabled
        self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_basic(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(x, y)

        x = torch.rand((3, 4), dtype=torch.float).cuda()
        y = torch.rand((4, 5), dtype=torch.float).cuda()

        mod = TestModule().eval()

        # sanity check
        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)

        # make sure that the runtime pass doesn't duplicate autocast nodes
        frozen_mod(x, y)
        optimized_graph = frozen_mod.graph_for(x, y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_constants(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.rand((3, 4), dtype=torch.float).cuda()

            def forward(self, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(self.x, y)

        y = torch.rand((4, 5), dtype=torch.float).cuda()
        mod = TestModule().eval()

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        # freezing should pre-cast the constant self.x to remove one autocast call
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)

        # the runtime autocasting pass will re-insert the second autocast call,
        # but constant propagation will merge it with the constant that it's casting.
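        # so only a single cast is expected to remain in the optimized graph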
        frozen_mod(y)
        optimized_graph = frozen_mod.graph_for(y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)

    @unittest.skipIf(TEST_CUDA, "CPU-only test")
    def test_jit_autocast_softmax_cpu(self):
        def fn(x):
            with torch.cpu.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.bfloat16)
        fn_s(x)
        y = fn_s(x)

        self.assertTrue(y.dtype == torch.bfloat16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_autocast_softmax_gpu(self):
        def fn(x):
            with torch.cuda.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.half).cuda()
        fn_s(x)
        y = fn_s(x)

        self.assertTrue(y.dtype == torch.float)

    def test_ignore_amp(self):
        @torch.jit.script
        def foo(x):
            return torch.mm(x, x)

        inp = torch.rand([10, 10], dtype=torch.float)
        foo._set_ignore_amp(True)
        with torch.cpu.amp.autocast():
            foo(inp)
            foo(inp)

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("_autocast_to_reduced").run(g)

class convbn(torch.nn.Module):
    def __init__(self, bias_enabled=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        return self.bn(self.conv(x))

@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestJitTraceAutocast(JitTestCase):
    def setUp(self):
        super().setUp()
        self.previous_default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        self.models = [MnistNet(),
                       convbn(bias_enabled=True),
                       convbn(bias_enabled=False)]
        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu')]
        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
        torch.set_default_dtype(self.previous_default_dtype)
        super().tearDown()

    def test_generate_autocast_jit_trace_model(self):
        def test_generate_autocast_jit_trace_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
        for i in range(self.models.__len__()):
            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nchw_autocast_jit_trace_model(self):
        def test_nchw_autocast_jit_trace_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone())
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone())
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(self.models.__len__()):
            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nhwc_autocast_jit_trace_model(self):
        def test_nhwc_autocast_jit_trace_model(model, x):
            model = model.to(memory_format=torch.channels_last)
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone().to(memory_format=torch.channels_last))
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone().to(memory_format=torch.channels_last))
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(self.models.__len__()):
            if self.inputs[i].size().__len__() == 5:
                # the NHWC 3D case is not supported yet
                continue
            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_cat_promote(self):
        class TestModel(torch.nn.Module):
            def forward(self, a, b):
                return torch.cat([a, b], 0)

        with torch.jit.fuser("none"):
            # In this testcase, we will check whether cat has done the promotion in AMP with mixed dtype inputs.
            # To avoid the fusion group from TE, we will disable the fuser here.
            for jit_freeze_or_not in [False, True]:
                test_model = TestModel().eval()
                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                    a = torch.rand(24, 128, 128)
                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                    c = test_model(a, b)
                    traced = torch.jit.trace(test_model, (a, b))
                if jit_freeze_or_not:
                    traced = torch.jit.freeze(traced)
                for _ in range(3):
                    c2 = traced(a, b)
                self.assertEqual(c.dtype, torch.float32)
                self.assertEqual(c2.dtype, torch.float32)
                traced_graph = traced.graph_for(a, b)
                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))

    def test_script_autocast_cpu(self):
        def fn(x):
            if torch.is_autocast_cpu_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cpu.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_autocast_cuda(self):
        def fn(x):
            if torch.is_autocast_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cuda.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()))


    def test_scripted_aliasing(self):
        # torch.is_autocast_enabled should not be able to move inside of the autocast context.
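        # (the alias analysis check below verifies that the query cannot be moved
        # past the prim::Enter node of the autocast block)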
        def fn(x):
            if torch.is_autocast_enabled():
                y = True
            else:
                y = False
            with torch.cuda.amp.autocast(enabled=True):
                z = x.relu()
            return y, z

        fn_s = torch.jit.script(fn)
        graph = fn_s.graph

        aliasdb = graph.alias_db()

        is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
        enter_nodes = graph.findAllNodes("prim::Enter")

        self.assertEqual(len(is_enabled_nodes), 1)
        self.assertEqual(len(enter_nodes), 1)

        self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))


    def test_script_autocast_enable_and_check(self):
        def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
            b1 = torch.is_autocast_cpu_enabled()
            v1 = torch.mm(x, y)
            with torch.cpu.amp.autocast(enabled=True):
                b2 = torch.is_autocast_cpu_enabled()
                v2 = torch.mm(x, y)
                with torch.cpu.amp.autocast(enabled=False):
                    b3 = torch.is_autocast_cpu_enabled()
                    v3 = torch.mm(x, y)
            return (v1, b1, v2, b2, v3, b3)

        # bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
        def check_fn_results(arr):
            [v1, b1, v2, b2, v3, b3] = arr
            self.assertTrue((v1.dtype == torch.float) != b1)
            self.assertTrue((v2.dtype == torch.float) != b2)
            self.assertTrue((v3.dtype == torch.float) != b3)

        x = torch.rand((2, 2), dtype=torch.float)
        y = torch.rand((2, 2), dtype=torch.float)

        fn_s = torch.jit.script(fn)

        with torch.cpu.amp.autocast(enabled=False):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))

        with torch.cpu.amp.autocast(enabled=True):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))


if __name__ == "__main__":
    run_tests()