# Owner(s): ["module: unknown"]

import unittest

import torch
from torch.testing._internal.autocast_test_lists import (
    AutocastCPUTestLists,
    TestAutocast,
)
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode


class TestAutocastCPU(TestAutocast):
    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    # Each list-driven test below exercises its ops twice: once with the
    # default amp_dtype and once with amp_dtype=torch.float16.
    @skipIfTorchDynamo()
    def test_autocast_torch_expect_builtin_promote(self):
        for (
            op,
            args1,
            args2,
            out_type,
        ) in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(
                op, args1, torch.float32, device="cpu", out_type=out_type
            )
            self._run_autocast_outofplace(
                op,
                args2,
                torch.float32,
                device="cpu",
                out_type=out_type,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_methods_expect_builtin_promote(self):
        for (
            op,
            args1,
            args2,
            out_type,
        ) in self.autocast_lists.methods_expect_builtin_promote:
            self._run_autocast_outofplace(
                op, args1, torch.float32, device="cpu", module=None, out_type=out_type
            )
            self._run_autocast_outofplace(
                op,
                args2,
                torch.float32,
                device="cpu",
                module=None,
                out_type=out_type,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_16(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                device="cpu",
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_nn_16(self):
        for op_with_args in self.autocast_lists.nn_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op,
                args,
                torch.bfloat16,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_fp32(self):
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_nn_fp32(self):
        for op_with_args in self.autocast_lists.nn_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_need_autocast_promote(self):
        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args1, torch.float32, device="cpu")
            self._run_autocast_outofplace(
                op, args2, torch.float32, device="cpu", amp_dtype=torch.float16
            )
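
    # The LSTM in the next test is converted to bfloat16, so calling it on
    # float32 inputs without autocast should trip the RNN dtype check ("input
    # must have the type ..."), while the same call under CPU autocast should
    # succeed because the inputs are autocast to bfloat16 on the mkldnn bf16
    # path.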
    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def test_autocast_rnn(self):
        if (
            torch.backends.mkldnn.is_available()
            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
        ):
            x = torch.randn(1, 2, 1)
            hx = torch.randn(2, 2, 1)
            cx = torch.randn(2, 2, 1)

            m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)

            # Raise ValueError when autocast is not enabled
            with self.assertRaisesRegex(ValueError, "input must have the type"):
                m(x, (hx, cx))

            # Should be able to run the below case with autocast
            with torch.amp.autocast(device_type="cpu"):
                m(x, (hx, cx))

    def test_autocast_disabled_with_fp32_dtype(self):
        with torch.autocast(device_type="cpu", dtype=torch.float32, enabled=False):
            _ = torch.ones(10)

    def test_generic_autocast(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            with torch.amp.autocast(device_type="cpu"):
                generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
            with torch.amp.autocast(device_type="cpu"):
                cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
            self.assertEqual(generic_autocast_output, cpu_autocast_output)
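
    # A minimal usage sketch for context (assumes the documented defaults): on
    # CPU, torch.amp.autocast defaults to torch.bfloat16, so e.g.
    #
    #     with torch.amp.autocast(device_type="cpu"):
    #         y = torch.mm(torch.randn(2, 2), torch.randn(2, 2))
    #     # y.dtype should be torch.bfloat16
    #
    # The test below only checks that the legacy torch.cpu.amp.autocast() entry
    # point still works while emitting a FutureWarning.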
    def test_cpu_autocast_deprecated_warning(self):
        with self.assertWarnsRegex(
            FutureWarning,
            r"`torch.cpu.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cpu', args...\)` instead.",
        ):
            with torch.cpu.amp.autocast():
                _ = torch.ones(10)


# Custom autograd Function that re-enters autocast in its backward pass; used
# below to exercise the autocast weight cache across forward and backward.
class CustomLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w_t):
        ctx.save_for_backward(x, w_t)
        return torch.nn.functional.linear(x, w_t)

    @staticmethod
    def backward(ctx, grad_output):
        x, w_t = ctx.saved_tensors
        with torch.autocast(device_type="cuda"):
            dL_dX = torch.matmul(grad_output, w_t)
            dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
        return dL_dX, dL_dW


# Dispatch mode that counts how many times `weight` is cast to float16 via
# aten._to_copy. It also patches torch.clear_autocast_cache to a no-op so the
# autocast cache is not flushed between the autocast regions under test.
class WeightDTypeCastCounterMode(TorchDispatchMode):
    def __init__(self, weight):
        super().__init__()
        self.dtype_cast_counter = 0
        self.weight = weight

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if (
            func is torch.ops.aten._to_copy.default
            and args[0] is self.weight
            and kwargs["dtype"] is torch.float16
        ):
            self.dtype_cast_counter += 1
        return func(*args, **kwargs)

    def __enter__(self):
        self.old_clear_cache = torch.clear_autocast_cache
        torch.clear_autocast_cache = lambda: None
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.clear_autocast_cache = self.old_clear_cache
        return super().__exit__(exc_type, exc_val, exc_tb)


@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight only gets cast to float16 once.
        """

        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        with WeightDTypeCastCounterMode(weight) as mode:
            with torch.autocast(device_type="cuda"):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()

        self.assertEqual(mode.dtype_cast_counter, 1)

    def test_cache_disabled(self):
        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        try:
            # Marking the weight as a cached tensor should opt it out of the
            # autocast weight cache, so its float16 conversion is redone in the
            # backward pass.
            torch._C._set_cached_tensors_enabled(True)
            torch._C._add_cached_tensor(weight)

            with WeightDTypeCastCounterMode(weight) as mode:
                with torch.autocast(device_type="cuda"):
                    output = CustomLinear.apply(data, weight)
                    s = output.sum()
                s.backward()

            # we should not have cached the conversion of the weight
            self.assertEqual(mode.dtype_cast_counter, 2)

        finally:
            torch._C._set_cached_tensors_enabled(False)

    # index_put under AMP follows a cast policy called "promote",
    # https://github.com/pytorch/pytorch/blob/4fcd15a667df5b80e81db6563d8d3123a0cbd051/aten/src/ATen/autocast_mode.h#L205-L230
    # That means:
    # (1) double precision is ignored,
    # (2) if any argument is float, then all arguments are promoted to float,
    # (3) if all arguments are of lower-precision dtypes, they must all equal the current amp autocast dtype.
    # Since the AMP autocast dtype is thread-local, it is not preserved across
    # thread boundaries during autograd execution, and due to the multi-threaded
    # nature of autograd, the forward pass runs in bfloat16 while the backward
    # pass defaults to float16. The dtype mismatch leads to an error in the
    # promote policy, since criterion (3) is not satisfied.
    # For more info see https://github.com/pytorch/pytorch/issues/132715.
    def test_autocast_prioritize(self):
        device = "cuda"
        dtype = torch.bfloat16

        with torch.autocast(device_type=device, enabled=True, dtype=dtype):
            t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True)
            index = torch.randint(
                low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device
            )
            val = torch.randn(1, dtype=dtype, device=device)

            res = torch.index_put(t, [index], val)

            loss = res.mean()
            loss.backward()


@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps")
class TestAutocastMPS(TestCase):
    def test_cast_cache_is_global(self):
        class CustomLinear(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, w_t):
                ctx.save_for_backward(x, w_t)
                return torch.nn.functional.linear(x, w_t)

            @staticmethod
            def backward(ctx, grad_output):
                x, w_t = ctx.saved_tensors
                with torch.autocast(device_type="mps"):
                    dL_dX = torch.matmul(grad_output, w_t)
                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
                return dL_dX, dL_dW

        data = torch.randn(2, 3).to("mps")
        weight = torch.nn.Parameter(torch.randn(4, 3).to("mps"))
        weight_dtype_cast_counter = 0

        class WeightDTypeCastCounterMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if (
                    func is torch.ops.aten._to_copy.default
                    and args[0] is weight
                    and kwargs["dtype"] is torch.float16
                ):
                    nonlocal weight_dtype_cast_counter
                    weight_dtype_cast_counter += 1
                return func(*args, **kwargs)

            def __enter__(self):
                # self.old_clear_cache = torch.clear_autocast_cache
                # torch.clear_autocast_cache = lambda: None
                return super().__enter__()

            def __exit__(self, exc_type, exc_val, exc_tb):
                # torch.clear_autocast_cache = self.old_clear_cache
                return super().__exit__(exc_type, exc_val, exc_tb)

        with WeightDTypeCastCounterMode():
            with torch.autocast(device_type="mps"):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()
        self.assertEqual(weight_dtype_cast_counter, 2)


class TestTorchAutocast(TestCase):
    def test_autocast_fast_dtype(self):
        gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda")
        cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu")
        self.assertEqual(gpu_fast_dtype, torch.half)
        self.assertEqual(cpu_fast_dtype, torch.bfloat16)

    def test_invalid_device(self):
        dev = "not a real device"
        msg = f"Invalid device string: '{dev}'"
        with self.assertRaisesRegex(RuntimeError, msg):
            with torch.autocast(device_type=dev):
                _ = torch.tensor(1)
        with self.assertRaisesRegex(RuntimeError, msg):
            assert torch.amp.is_autocast_available(device_type=dev)

    def test_non_string_device(self):
        """Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string"""
        dev = torch.device("cpu")
        msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`"
        with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
            torch.autocast(device_type=dev)


if __name__ == "__main__":
    run_tests()