# Owner(s): ["module: unknown"]

import unittest

import torch
from torch.testing._internal.autocast_test_lists import (
    AutocastCPUTestLists,
    TestAutocast,
)
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode


class TestAutocastCPU(TestAutocast):
    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    @skipIfTorchDynamo()
    def test_autocast_torch_expect_builtin_promote(self):
        for (
            op,
            args1,
            args2,
            out_type,
        ) in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(
                op, args1, torch.float32, device="cpu", out_type=out_type
            )
            self._run_autocast_outofplace(
                op,
                args2,
                torch.float32,
                device="cpu",
                out_type=out_type,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_methods_expect_builtin_promote(self):
        for (
            op,
            args1,
            args2,
            out_type,
        ) in self.autocast_lists.methods_expect_builtin_promote:
            self._run_autocast_outofplace(
                op, args1, torch.float32, device="cpu", module=None, out_type=out_type
            )
            self._run_autocast_outofplace(
                op,
                args2,
                torch.float32,
                device="cpu",
                module=None,
                out_type=out_type,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_16(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                device="cpu",
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_nn_16(self):
        for op_with_args in self.autocast_lists.nn_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op,
                args,
                torch.bfloat16,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_fp32(self):
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_nn_fp32(self):
        for op_with_args in self.autocast_lists.nn_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_need_autocast_promote(self):
        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
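            # args1 exercises the promote policy under the default bfloat16
            # autocast dtype; args2 repeats the check with float16.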
            self._run_autocast_outofplace(op, args1, torch.float32, device="cpu")
            self._run_autocast_outofplace(
                op, args2, torch.float32, device="cpu", amp_dtype=torch.float16
            )

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def test_autocast_rnn(self):
        if (
            torch.backends.mkldnn.is_available()
            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
        ):
            x = torch.randn(1, 2, 1)
            hx = torch.randn(2, 2, 1)
            cx = torch.randn(2, 2, 1)

            m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)

            # Raises a ValueError when autocast is not enabled
            with self.assertRaisesRegex(ValueError, "input must have the type"):
                m(x, (hx, cx))

            # The same call should run under autocast
            with torch.amp.autocast(device_type="cpu"):
                m(x, (hx, cx))

    def test_autocast_disabled_with_fp32_dtype(self):
        with torch.autocast(device_type="cpu", dtype=torch.float32, enabled=False):
            _ = torch.ones(10)

    def test_generic_autocast(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            with torch.amp.autocast(device_type="cpu"):
                generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
            with torch.amp.autocast(device_type="cpu"):
                cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
            self.assertEqual(generic_autocast_output, cpu_autocast_output)

    def test_cpu_autocast_deprecated_warning(self):
        with self.assertWarnsRegex(
            FutureWarning,
            r"`torch.cpu.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cpu', args...\)` instead.",
        ):
            with torch.cpu.amp.autocast():
                _ = torch.ones(10)


class CustomLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w_t):
        ctx.save_for_backward(x, w_t)
        return torch.nn.functional.linear(x, w_t)

    @staticmethod
    def backward(ctx, grad_output):
        x, w_t = ctx.saved_tensors
        with torch.autocast(device_type="cuda"):
            dL_dX = torch.matmul(grad_output, w_t)
            dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
        return dL_dX, dL_dW


class WeightDTypeCastCounterMode(TorchDispatchMode):
    def __init__(self, weight):
        super().__init__()
        self.dtype_cast_counter = 0
        self.weight = weight

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if (
            func is torch.ops.aten._to_copy.default
            and args[0] is self.weight
            and kwargs["dtype"] is torch.float16
        ):
            self.dtype_cast_counter += 1
        return func(*args, **kwargs)

    def __enter__(self):
        # Mock out autocast cache clearing so cached casts survive past the forward pass.
        self.old_clear_cache = torch.clear_autocast_cache
        torch.clear_autocast_cache = lambda: None
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.clear_autocast_cache = self.old_clear_cache
        return super().__exit__(exc_type, exc_val, exc_tb)


@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight only gets cast to float16 once.
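        Because autograd may run the backward pass on a separate thread, a
        second cast would indicate that the cache entry created during the
        forward pass was not shared with the backward pass.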
243 """ 244 245 data = torch.randn(2, 3).cuda() 246 weight = torch.nn.Parameter(torch.randn(4, 3).cuda()) 247 248 with WeightDTypeCastCounterMode(weight) as mode: 249 with torch.autocast(device_type="cuda"): 250 output = CustomLinear.apply(data, weight) 251 s = output.sum() 252 s.backward() 253 254 self.assertEqual(mode.dtype_cast_counter, 1) 255 256 def test_cache_disabled(self): 257 data = torch.randn(2, 3).cuda() 258 weight = torch.nn.Parameter(torch.randn(4, 3).cuda()) 259 260 try: 261 torch._C._set_cached_tensors_enabled(True) 262 torch._C._add_cached_tensor(weight) 263 264 with WeightDTypeCastCounterMode(weight) as mode: 265 with torch.autocast(device_type="cuda"): 266 output = CustomLinear.apply(data, weight) 267 s = output.sum() 268 s.backward() 269 270 # we should not have cached the conversion of the weight 271 self.assertEqual(mode.dtype_cast_counter, 2) 272 273 finally: 274 torch._C._set_cached_tensors_enabled(False) 275 276 # index_put under AMP follows a cast policy called "promote", 277 # https://github.com/pytorch/pytorch/blob/4fcd15a667df5b80e81db6563d8d3123a0cbd051/aten/src/ATen/autocast_mode.h#L205-L230 278 # That means: 279 # (1) double precision is ignored, 280 # (2) if any argument is float, then all arguments are promoted to float, 281 # (3) if all arguments are of lower precision dtype, then all dtypes must be equal to the same amp autocast dtype. 282 # Since AMP autocast dtype is thread-local, it is not preserved across thread boundaries during autograd execution, 283 # and due to the multi-threaded nature of the autograd, the forward pass is being run in bfloat16, while the backward 284 # pass defaults to float16. The dtype mismatch leads to the error in the policy, as the criteria (3) is not satisfied. 285 # For more info see https://github.com/pytorch/pytorch/issues/132715. 
286 def test_autocast_prioritize(self): 287 device = "cuda" 288 dtype = torch.bfloat16 289 290 with torch.autocast(device_type=device, enabled=True, dtype=dtype): 291 t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True) 292 index = torch.randint( 293 low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device 294 ) 295 val = torch.randn(1, dtype=dtype, device=device) 296 297 res = torch.index_put(t, [index], val) 298 299 loss = res.mean() 300 loss.backward() 301 302 303@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps") 304class TestAutocastMPS(TestCase): 305 def test_cast_cache_is_global(self): 306 class CustomLinear(torch.autograd.Function): 307 @staticmethod 308 def forward(ctx, x, w_t): 309 ctx.save_for_backward(x, w_t) 310 return torch.nn.functional.linear(x, w_t) 311 312 @staticmethod 313 def backward(ctx, grad_output): 314 x, w_t = ctx.saved_tensors 315 with torch.autocast(device_type="mps"): 316 dL_dX = torch.matmul(grad_output, w_t) 317 dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1) 318 return dL_dX, dL_dW 319 320 data = torch.randn(2, 3).to("mps") 321 weight = torch.nn.Parameter(torch.randn(4, 3).to("mps")) 322 weight_dtype_cast_counter = 0 323 324 class WeightDTypeCastCounterMode(TorchDispatchMode): 325 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 326 if ( 327 func is torch.ops.aten._to_copy.default 328 and args[0] is weight 329 and kwargs["dtype"] is torch.float16 330 ): 331 nonlocal weight_dtype_cast_counter 332 weight_dtype_cast_counter += 1 333 return func(*args, **kwargs) 334 335 def __enter__(self): 336 # self.old_clear_cache = torch.clear_autocast_cache 337 # torch.clear_autocast_cache = lambda: None 338 return super().__enter__() 339 340 def __exit__(self, exc_type, exc_val, exc_tb): 341 # torch.clear_autocast_cache = self.old_clear_cache 342 return super().__exit__(exc_type, exc_val, exc_tb) 343 344 with WeightDTypeCastCounterMode(): 345 with torch.autocast(device_type="mps"): 346 output = CustomLinear.apply(data, weight) 347 s = output.sum() 348 s.backward() 349 self.assertEqual(weight_dtype_cast_counter, 2) 350 351 352class TestTorchAutocast(TestCase): 353 def test_autocast_fast_dtype(self): 354 gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda") 355 cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu") 356 self.assertEqual(gpu_fast_dtype, torch.half) 357 self.assertEqual(cpu_fast_dtype, torch.bfloat16) 358 359 def test_invalid_device(self): 360 dev = "not a real device" 361 msg = f"Invalid device string: '{dev}'" 362 with self.assertRaisesRegex(RuntimeError, msg): 363 with torch.autocast(device_type=dev): 364 _ = torch.tensor(1) 365 with self.assertRaisesRegex(RuntimeError, msg): 366 assert torch.amp.is_autocast_available(device_type=dev) 367 368 def test_non_string_device(self): 369 """Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string""" 370 dev = torch.device("cpu") 371 msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`" 372 with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg): 373 torch.autocast(device_type=dev) 374 375 376if __name__ == "__main__": 377 run_tests() 378