# Owner(s): ["oncall: jit"]

import torch
from torch.cuda.amp import autocast
from typing import Optional, Tuple

import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing import FileCheck
from jit.test_models import MnistNet

TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()

@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestAutocast(JitTestCase):
    def setUp(self):
        # common input tensors
        if TEST_CUDA:
            self.a_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.b_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.c_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.d_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.a_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.b_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.c_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.d_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
        self.old_value = torch._C._jit_set_autocast_mode(True)
        super().setUp()

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.old_value)
        super().tearDown()
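
    # The tests in this class script small functions under autocast and check
    # the dtypes produced by the JIT autocast pass; the pass is enabled above
    # via torch._C._jit_set_autocast_mode(True) and restored in tearDown.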

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_generic_autocast(self):
        @torch.jit.script
        def fn_cuda_autocast(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y

        @torch.jit.script
        def fn_generic_autocast(a, b):
            with torch.amp.autocast(device_type='cuda'):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32), fn_generic_autocast(self.a_fp32, self.b_fp32))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
    def test_linear_bf16(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(dtype=torch.bfloat16):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.bfloat16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_cpu(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)
        result = fn(self.a_fp32.to('cpu'), self.b_fp32.to('cpu'))
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_off(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)
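
    # For reference, the eager-mode behavior mirrored by the tests above
    # (a minimal sketch, not part of the test suite):
    #
    #   with torch.cuda.amp.autocast():
    #       x = torch.mm(a_fp32, b_fp32)   # matmul runs in float16
    #       y = torch.sum(x)               # reduction stays in float32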

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state(self):
        @torch.jit.script
        def fn(a, b, use_amp: bool):
            with autocast(enabled=use_amp):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32, True)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state_expr(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True if a[0][0] > 0.5 else False):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_explicit_casts(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a.double(), b.double()).float()
                f = torch.mm(c, d).double()
                g = torch.mm(c.double(), f)
                return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float64)
        self.assertEqual(g.dtype, torch.float64)

    # multiple uses of the same input value
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_duplicate_inputs(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                e = torch.mm(a, a)
                f = torch.mm(e, e)
                return e, f
        e, f = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
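
    # The next few tests exercise the other autocast casting policies: ops such
    # as torch.log fall under the fp32 policy, torch.addcmul under the promote
    # policy, and torch.softmax under the fp32_set_opt_dtype policy.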

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        result = fn(self.a_fp16)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy_with_fp64(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        # fp32 policy should not narrow fp64 to fp32!
        result = fn(self.a_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a, b)
                f = torch.addcmul(e, c, d, value=0.1)
                return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy_fp64(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return torch.addcmul(a, a, b, value=0.1)
        result = fn(self.a_fp32.double(), self.b_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
                return x, y, z, w
        x, y, z, w = fn(self.a_fp16, self.b_fp16, self.c_fp16, self.d_fp16, None)
        self.assertEqual(x.dtype, torch.float32)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy_fp64(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
                return x, y, z, w
        x, y, z, w = fn(self.a_fp32.double(), self.b_fp32.double(), self.c_fp32.double(), self.d_fp32.double(), None)
        self.assertEqual(x.dtype, torch.float64)
        self.assertEqual(y.dtype, torch.float64)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float64)

    @unittest.skipIf(True, "broken due to lack of type propagation")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_control_flow(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    x = 1
                else:
                    e = torch.mm(c, d)
                    x = 2
                f = torch.mm(d, e) * x
                return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    # this works fine in regular Python, but it creates a delicate
    # situation in TorchScript where the types are not consistent across
    # the then/else branches
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_types(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    f = torch.mm(a, b).float()
                else:
                    e = torch.mm(c, d).float()
                    f = torch.mm(a, b)
                return torch.mm(e.float(), f.float())
        result = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # another, more complex case of divergent types
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            if a[0][0] > 0.5:
                with autocast_on:
                    e = torch.mm(a, b)
            else:
                with autocast_off:
                    e = torch.mm(c, d)
            return torch.mm(e, e)
        fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_conditional_autocast(self):
        @torch.jit.script
        def fn(a, b):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            with autocast_on if a[0][0] > 0.5 else autocast_off:
                return torch.mm(a, b)
        # conditional autocast expressions are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)
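
    # In the nested test below the innermost autocast state wins: e stays
    # float32 (outer region disabled), f is float16 (inner region enabled),
    # and g is float32 again (innermost region disabled).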

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_nested_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=False):
                e = torch.mm(a, b)
                with autocast(enabled=True):
                    f = torch.mm(e, c)
                    with autocast(enabled=False):
                        g = torch.mm(e, d)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_implicitly_nested_autocast(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False), autocast(enabled=True):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_reused_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_instance = autocast(enabled=True)
            with autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    # TODO: fix and enable this test?
    # (we could technically fix this, but is it really worth it?)
    @unittest.skipIf(True, "unsupported autocast syntax")
    def test_reused_autocast_expr(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=True) as autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees(self):
        def helper(a, b):
            return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                tmp = helper(a, b)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                return helper(tmp, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_on(self):
        def helper(a, b):
            with autocast(enabled=True):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_off(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripting inside eager autocast
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_eager_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        for i in range(8):
            use_autocast = (i % 2 == 0)
            expected_dtype = torch.float16 if use_autocast else torch.float32
            with autocast(enabled=use_autocast):
                result = fn(self.a_fp32, self.b_fp32)
            self.assertEqual(result.dtype, expected_dtype)

    # traced inside scripting
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing(self):
        def helper(a, b):
            return torch.mm(a, b)

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)
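
    # Note: when a traced function is called from a scripted function, the
    # caller's autocast region still applies to it (float16 result above),
    # but an autocast(enabled=False) block captured at trace time is ignored,
    # which is why the next test is skipped.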

    # traced with autocast inside scripting
    @unittest.skipIf(True, "autocast(False) is ignored inside traced functions")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing_with_autocast(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b) * 2.0

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripted called from traced
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_and_script(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)

        def traced(a, b):
            return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # scripted called from traced with autocast
    @unittest.skipIf(True, "scripted called from traced TorchScript is not yet working")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_with_autocast_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)

        def traced(a, b):
            with autocast(enabled=True):
                return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_module(self):
        class TestModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand((N, M), dtype=torch.float32))
                self.linear = torch.nn.Linear(N, M).float()

            def forward(self, input):
                with autocast(enabled=True):
                    output = self.weight.mv(input)
                    output = self.linear(output)
                    return output

        scripted_module = torch.jit.script(TestModule(2, 3)).cuda()
        input = torch.rand(3, dtype=torch.float32, device='cuda')
        result = scripted_module(input)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(True, "autocast decorators not supported")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator(self):
        @torch.jit.script
        @autocast(enabled=True)
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # this is equivalent to running scripted functions inside autocast
    # (see also test_eager_and_script)
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator_outside_jit(self):
        @autocast(enabled=True)
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_inplace(self):
        @torch.jit.script
        def fn(a, b, c):
            with autocast(enabled=True):
                x = torch.addmm(a, b, c)
                y = torch.addmm(a, b, c, out=a)
                z = a.addmm_(b, c)
                return x, y, z
        x, y, z = fn(self.a_fp32, self.b_fp32, self.c_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float32)
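
    # Helper shared by the API tests below: scripts `func`, runs the eager and
    # scripted versions, optionally checks that `cast_op` shows up in the
    # optimized graph, and compares the dtypes of the two sets of outputs.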
    def _test_autocast(self, func, cast_op, *args):
        jit_func = torch.jit.script(func)
        o = func(*args)
        jit_o = jit_func(*args)
        if cast_op is not None:
            FileCheck().check(cast_op).run(jit_func.graph_for(*args))
        for o0, o1 in zip(o, jit_o):
            self.assertEqual(o0.dtype, o1.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api(self):

        def t_autocast_cpu(x, y):
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            with torch.autocast("cuda", dtype=torch.half):
                return torch.mm(x, y)

        def t_cuda_amp_autocast(x, y):
            with torch.cuda.amp.autocast():
                return torch.mm(x, y)

        def t_cpu_amp_autocast(x, y):
            with torch.cpu.amp.autocast():
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cuda_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cpu_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
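
    # All four autocast entry points above (torch.autocast("cpu"/"cuda"),
    # torch.cuda.amp.autocast and torch.cpu.amp.autocast) are expected to
    # lower to the same aten::_autocast_to_reduced_precision cast op when
    # scripted.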

    @unittest.skipIf(True, "we need to provide dtype argument at this moment")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api_not_supported(self):

        def t_autocast_cpu(x, y):
            # no dtype provided is not currently supported
            with torch.autocast("cpu"):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            # no dtype provided is not currently supported
            with torch.autocast("cuda"):
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_mixed_dtypes(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            with torch.autocast("cpu", torch.bfloat16):
                with torch.autocast("cuda", torch.float16):
                    cpu_o = torch.mm(cpu0, cpu1)
                    cuda_o = torch.mm(cuda0, cuda1)
                    return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_executor_under_autocast(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            cpu_o = torch.mm(cpu0, cpu1)
            cuda_o = torch.mm(cuda0, cuda1)
            return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)

        with torch.autocast("cpu", torch.bfloat16):
            with torch.autocast("cuda", torch.float16):
                self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cpu", torch.bfloat16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cuda", torch.float16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        # no cast op should be observed when executing outside autocast context
        self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_autodiff(self):
        def t(t0, t1):
            o = torch.mm(t0, t1)
            return o.relu()

        jit_t = torch.jit.script(t)
        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

        # run optimization
        for i in range(5):
            with torch.autocast("cuda", torch.float16):
                jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

        t0.grad = None
        t1.grad = None
        ref_t0 = t0.detach().requires_grad_()
        ref_t1 = t1.detach().requires_grad_()

        with torch.autocast("cuda", torch.float16):
            o = t(ref_t0, ref_t1)
            jit_o = jit_t(t0, t1)
        jit_o.sum().backward()
        o.sum().backward()
        self.assertEqual(o, jit_o)
        self.assertEqual(t0.grad, ref_t0.grad)
        self.assertEqual(t1.grad, ref_t1.grad)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_call_method_under_autocast(self):
        @torch.jit.interface
        class Iface(torch.nn.Module):
            def forward(self, x, y) -> torch.Tensor:
                pass

        class Impl(Iface):
            def forward(self, x, y):
                return torch.mm(x, y)

        class Thing1(torch.nn.Module):
            impl: Iface

            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    a = torch.mm(x, y)
                    b = self.impl.forward(a, x)
                return b

        scripted_impl = torch.jit.script(Impl())
        thing1 = Thing1()
        thing1.impl = scripted_impl
        scripted_thing1 = torch.jit.script(thing1)
        x = torch.rand([2, 2])
        y = torch.rand([2, 2])

        # make sure this doesn't throw an error
        with torch.cuda.amp.autocast():
            ans = scripted_thing1.forward(x, y)
        self.assertEqual(torch.mm(torch.mm(x, y), x), ans)

        # sanity check: this isn't supported currently when global autocasting
        # isn't enabled
        self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_basic(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(x, y)

        x = torch.rand((3, 4), dtype=torch.float).cuda()
        y = torch.rand((4, 5), dtype=torch.float).cuda()

        mod = TestModule().eval()

        # sanity check
        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)

        # make sure that the runtime pass doesn't duplicate autocast nodes
        frozen_mod(x, y)
        optimized_graph = frozen_mod.graph_for(x, y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_constants(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.rand((3, 4), dtype=torch.float).cuda()

            def forward(self, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(self.x, y)

        y = torch.rand((4, 5), dtype=torch.float).cuda()
        mod = TestModule().eval()

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        # freezing should pre-cast the constant self.x to remove one autocast call
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)

        # the runtime autocasting pass will re-insert the second autocast call,
        # but constant propagation will merge it with the constant that it's casting.
        frozen_mod(y)
        optimized_graph = frozen_mod.graph_for(y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)

    @unittest.skipIf(TEST_CUDA, "CPU-only test")
    def test_jit_autocast_softmax_cpu(self):
        def fn(x):
            with torch.cpu.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.bfloat16)
        fn_s(x)
        y = fn_s(x)

        self.assertTrue(y.dtype == torch.bfloat16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_autocast_softmax_gpu(self):
        def fn(x):
            with torch.cuda.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.half).cuda()
        fn_s(x)
        y = fn_s(x)

        self.assertTrue(y.dtype == torch.float)

    def test_ignore_amp(self):
        @torch.jit.script
        def foo(x):
            return torch.mm(x, x)

        inp = torch.rand([10, 10], dtype=torch.float)
        foo._set_ignore_amp(True)
        with torch.cpu.amp.autocast():
            foo(inp)
            foo(inp)

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("_autocast_to_reduced").run(g)
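
# The tracing tests below run with the JIT autocast pass disabled (see setUp)
# and rely on tracing the models under eager torch.cpu.amp.autocast.
# convbn is a small conv + batchnorm module used as one of the test models.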
class convbn(torch.nn.Module):
    def __init__(self, bias_enabled=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        return self.bn(self.conv(x))


@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestJitTraceAutocast(JitTestCase):
    def setUp(self):
        super().setUp()
        self.previous_default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        self.models = [MnistNet(),
                       convbn(bias_enabled=True),
                       convbn(bias_enabled=False)]
        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu')]
        # These tests exercise eager autocast during tracing, so the JIT scripting autocast pass is disabled.
        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
        torch.set_default_dtype(self.previous_default_dtype)
        super().tearDown()

    def test_generate_autocast_jit_trace_model(self):
        def test_generate_autocast_jit_trace_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
        for i in range(len(self.models)):
            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nchw_autocast_jit_trace_model(self):
        def test_nchw_autocast_jit_trace_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone())
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone())
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nhwc_autocast_jit_trace_model(self):
        def test_nhwc_autocast_jit_trace_model(model, x):
            model = model.to(memory_format=torch.channels_last)
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone().to(memory_format=torch.channels_last))
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone().to(memory_format=torch.channels_last))
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            if self.inputs[i].dim() == 5:
                # NHWC 3D case is not supported yet
                continue
            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_cat_promote(self):
        class TestModel(torch.nn.Module):
            def forward(self, a, b):
                return torch.cat([a, b], 0)

        with torch.jit.fuser("none"):
            # In this test case, we check whether cat promotes its mixed-dtype inputs under AMP.
            # The TE fuser is disabled here to keep cat out of a fusion group.
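            # With one float32 and one bfloat16 input, the traced graph should promote via an
            # aten::to node and produce a float32 result, with or without freezing.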
            for jit_freeze_or_not in [False, True]:
                test_model = TestModel().eval()
                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                    a = torch.rand(24, 128, 128)
                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                    c = test_model(a, b)
                    traced = torch.jit.trace(test_model, (a, b))
                if jit_freeze_or_not:
                    traced = torch.jit.freeze(traced)
                for _ in range(3):
                    c2 = traced(a, b)
                self.assertEqual(c.dtype, torch.float32)
                self.assertEqual(c2.dtype, torch.float32)
                traced_graph = traced.graph_for(a, b)
                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))

    def test_script_autocast_cpu(self):
        def fn(x):
            if torch.is_autocast_cpu_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cpu.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_autocast_cuda(self):
        def fn(x):
            if torch.is_autocast_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cuda.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()))

    def test_scripted_aliasing(self):
        # torch.is_autocast_enabled should not be able to move inside of the autocast context.
        def fn(x):
            if torch.is_autocast_enabled():
                y = True
            else:
                y = False
            with torch.cuda.amp.autocast(enabled=True):
                z = x.relu()
            return y, z

        fn_s = torch.jit.script(fn)
        graph = fn_s.graph

        aliasdb = graph.alias_db()

        is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
        enter_nodes = graph.findAllNodes("prim::Enter")

        self.assertEqual(len(is_enabled_nodes), 1)
        self.assertEqual(len(enter_nodes), 1)

        self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))

    def test_script_autocast_enable_and_check(self):
        def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
            b1 = torch.is_autocast_cpu_enabled()
            v1 = torch.mm(x, y)
            with torch.cpu.amp.autocast(enabled=True):
                b2 = torch.is_autocast_cpu_enabled()
                v2 = torch.mm(x, y)
                with torch.cpu.amp.autocast(enabled=False):
                    b3 = torch.is_autocast_cpu_enabled()
                    v3 = torch.mm(x, y)
            return (v1, b1, v2, b2, v3, b3)

        # bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
        def check_fn_results(arr):
            [v1, b1, v2, b2, v3, b3] = arr
            self.assertTrue((v1.dtype == torch.float) != b1)
            self.assertTrue((v2.dtype == torch.float) != b2)
            self.assertTrue((v3.dtype == torch.float) != b3)

        x = torch.rand((2, 2), dtype=torch.float)
        y = torch.rand((2, 2), dtype=torch.float)

        fn_s = torch.jit.script(fn)

        with torch.cpu.amp.autocast(enabled=False):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))

        with torch.cpu.amp.autocast(enabled=True):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))


if __name__ == "__main__":
    run_tests()