# Owner(s): ["module: nn"]

from itertools import chain, product
from inspect import signature, isgenerator
from copy import deepcopy
import tempfile
from operator import methodcaller

import torch

from torch._subclasses.meta_utils import assert_metadata_eq
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, onlyCPU, onlyCUDA, toleranceOverride, tol, skipMeta)
from torch.testing._internal.common_modules import module_db, modules, ModuleErrorEnum, TrainEvalMode
from torch.testing._internal.common_utils import (
    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
    gradgradcheck, parametrize, wrapSwapTensorsTest)
from unittest.mock import patch, call


class TestModule(TestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
    precision = 1e-5
    rel_tol = 1e-5

    def _assert_module_parameters_and_buffer_are(self, module, device, dtype):
        # Check device placement and dtype for created parameters and buffers.
        # Only verify floating point dtypes since that's what the kwarg or methods
        # such as `float()` apply to.
        if not isinstance(device, torch.device):
            device = torch.device(device)

        def _check_module(items, name, device=device, dtype=dtype):
            for item_name, item in items:
                self.assertEqual(
                    item.device, device,
                    f'{name} {item_name} is on device {item.device} instead of the expected device {device}')
                if item.dtype.is_floating_point:
                    self.assertEqual(
                        item.dtype, dtype,
                        f'{name} {item_name} is of dtype {item.dtype} instead of the expected dtype {dtype}')
        _check_module(module.named_parameters(), "Parameter")
        _check_module(module.named_buffers(), "Buffer")

    @modules(module_db)
    def test_forward(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        dtype_to_method_caller = {
            torch.float32: methodcaller("float"),
            torch.float64: methodcaller("double"),
        }
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Do forward pass. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                outputs = m(*args, **kwargs)

                # === Compare outputs to a reference if one is specified. ===
                # TODO: Handle precision
                reference_fn = module_input.reference_fn
                if reference_fn is not None:
                    ref_outputs = reference_fn(m, *args, **kwargs)
                    self.assertEqual(outputs, ref_outputs)

                # === Use the method call and verify the parameters and buffers ===
                if dtype in dtype_to_method_caller:
                    dtype_to_method_caller[dtype](m)
                    m(*args, **kwargs)
                    self._assert_module_parameters_and_buffer_are(m, device, dtype)

    # Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
    # They should be applied to any created parameters and buffers.
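    # For example, torch.nn.Linear(3, 4, device=device, dtype=dtype) should place its `weight`
    # and `bias` directly on `device` with `dtype`, without a follow-up `.to()` call
    # (illustrative example only; the actual constructor inputs come from each ModuleInfo entry).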
    @modules(module_db)
    def test_factory_kwargs(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        for module_input in module_inputs:
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            # Check if this module creates parameters or registers buffers.
            # The mock magic here passes through to the real Parameter / register_buffer
            # logic and is only used to check call inputs.
            module_creates_params_or_buffers = False
            parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
            with patch.object(torch.nn.Parameter, '__new__', parameter_new):
                register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
                with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
                    m = module_cls(*args, **kwargs)
                    m.train(training)

                    # Check if a parameter or buffer was created with a tensor not passed to the constructor.
                    constructor_tensors = get_tensors_from(args, kwargs)
                    for mock in [parameter_new.mock, register_buffer.mock]:
                        for call_args, call_kwargs in mock.call_args_list:
                            call_tensors = get_tensors_from(call_args, call_kwargs)
                            if len(call_tensors) > 0 and not constructor_tensors.intersection(call_tensors):
                                module_creates_params_or_buffers = True
                                break

            if not module_creates_params_or_buffers:
                continue

            # Instantiate module with the factory kwargs.
            kwargs.update({
                'device': device,
                'dtype': dtype,
            })

            if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
                uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
                with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
                    uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
                    with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                        m = module_cls(*args, **kwargs)
                        m.train(training)
                        uninit_param_new.mock.assert_has_calls(
                            [call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
                        uninit_buffer_new.mock.assert_has_calls(
                            [call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
            else:
                # Check device placement and dtype for created parameters and buffers.
                # Only verify floating point dtypes since that's what the kwarg applies to.
                m = module_cls(*args, **kwargs)
                m.train(training)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

    @onlyCUDA
    @modules(module_db)
    def test_multiple_device_transfer(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                              requires_grad=False, training=training)
        module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
                                                           requires_grad=False, training=training)
        for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
            if module_input_device.forward_input is None:
                continue

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Do forward pass on GPU ===
                input_device_args = module_input_device.forward_input.args
                input_device_kwargs = module_input_device.forward_input.kwargs
                m(*input_device_args, **input_device_kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

                # === Move to CPU ===
                input_cpu_args = module_input_cpu.forward_input.args
                input_cpu_kwargs = module_input_cpu.forward_input.kwargs
                m.cpu()
                m(*input_cpu_args, **input_cpu_kwargs)
                self._assert_module_parameters_and_buffer_are(m, "cpu", dtype)

                # === Move back to GPU and forward pass ===
                m.cuda()
                m(*input_device_args, **input_device_kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

                if torch.cuda.device_count() >= 2:
                    # === test cross-GPU transfer works ===
                    def _to_device1(objs):
                        if isinstance(objs, (tuple, list)):
                            return type(objs)(_to_device1(item) for item in objs)
                        elif isinstance(objs, dict):
                            return {name: _to_device1(item) for name, item in objs.items()}
                        elif isinstance(objs, torch.Tensor):
                            return objs.cuda(1)
                        else:
                            return objs
                    input_device_1_args = _to_device1(input_device_args)
                    input_device_1_kwargs = _to_device1(input_device_kwargs)

                    m.cuda(1)
                    with torch.cuda.device(1):
                        m(*input_device_1_args, **input_device_1_kwargs)
                        self._assert_module_parameters_and_buffer_are(m, torch.device("cuda:1"), dtype)

    @modules(module_db)
    def test_repr(self, device, dtype, module_info, training):
        # Test module can be represented with repr and str without errors.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        for module_input in module_inputs:
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            # Check that these methods do not raise errors
            m.__repr__()
            str(m)

    @modules(module_db)
    def test_save_load(self, device, dtype, module_info, training):
        # Test that module can be pickled and unpickled.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)
                sd = m.state_dict()

                # === Do forward pass. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                output = m(*args, **kwargs)

                # === Check saved/loaded module gives the same output. ===
                with tempfile.TemporaryFile() as f:
                    torch.save(m, f)
                    f.seek(0)
                    # weights_only=False as this is legacy code that saves the model
                    m_copy = torch.load(f, weights_only=False)
                    output_from_copy = m_copy(*args, **kwargs)
                    self.assertEqual(output, output_from_copy)

                # === Check saved/loaded state_dict are the same (including weights_only load). ===
                with tempfile.TemporaryFile() as f:
                    torch.save(sd, f)
                    f.seek(0)
                    sd_copy = torch.load(f)
                    self.assertEqual(sd_copy, sd)
                    del sd_copy
                    f.seek(0)
                    sd_copy_wo = torch.load(f, weights_only=True)
                    self.assertEqual(sd_copy_wo, sd)

    @skipMeta
    @modules([module_info for module_info in module_db
              if 'inplace' in signature(module_info.module_cls).parameters])
    def test_check_inplace(self, device, dtype, module_info, training):
        # Check if the inplace variant of the module gives the same result as the out of place
        # variant.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m_op = module_cls(*args, **kwargs, inplace=False)
            m_op.to(device).to(dtype)
            m_op.train(training)
            m_inplace = module_cls(*args, **kwargs, inplace=True)
            m_inplace.to(device).to(dtype)
            m_inplace.train(training)

            # === Inplace modules only support inplace operations on the first argument ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

            # === Do not allow the first input to be in input_kwargs ===
            forward_sig = signature(m_op).parameters
            self.assertGreaterEqual(len(forward_sig), 1)
            first_param_name = next(iter(forward_sig))
            self.assertNotIn(first_param_name, input_kwargs)

            # === Out of place operation does not write to original tensor ===
            self.assertGreaterEqual(len(input_args), 1)
            input_version = input_args[0]._version
            with freeze_rng_state():
                output_op = m_op(*input_args, **input_kwargs)
            self.assertEqual(input_args[0]._version, input_version)

            # === Check that the inplace operation gives the same result ===
            input_arg_copy = deepcopy(input_args)
            input_arg_clone = tuple(i.clone() for i in input_arg_copy)
            input_clone_version = input_arg_clone[0]._version
            with freeze_rng_state():
                output_ip = m_inplace(*input_arg_clone, **input_kwargs)
            self.assertGreater(input_arg_clone[0]._version, input_clone_version)
            self.assertEqual(output_op, output_ip)

            # === Check that the gradients are the same ===
            grad = output_op.data.clone().normal_()
            output_op.backward(grad)
            output_ip.backward(grad)
            self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)

    def _traverse_obj(self, obj, func):
        if isinstance(obj, (tuple, list)):
            return type(obj)(self._traverse_obj(o, func) for o in obj)
        elif isgenerator(obj):
            return tuple(self._traverse_obj(o, func) for o in obj)
        elif isinstance(obj, dict):
            return {name: self._traverse_obj(o, func) for name, o in obj.items()}
        elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
            return func(obj)
        else:
            return obj

    def _retain_grad(self, obj):
        # gradients need to be retained to check for grad. This is useful when
        # non-leafs are present in the graph.
        def inner_retain_grad(obj):
            if obj.requires_grad:
                obj.retain_grad()
        self._traverse_obj(obj, inner_retain_grad)

    def _get_grads(self, obj):
        def inner_get_grad(obj):
            if obj.requires_grad:
                return obj.grad
        return self._traverse_obj(obj, inner_get_grad)

    def _zero_grad(self, obj):
        def inner_zero_grad(obj):
            if obj.grad is not None:
                obj.grad = None
        self._traverse_obj(obj, inner_zero_grad)

    @modules(module_db)
    def test_non_contiguous_tensors(self, device, dtype, module_info, training):
        # Check modules work with non-contiguous tensors

        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)

        def _make_non_contiguous(obj):
            def inner_make_non_contiguous(obj):
                # Scalar tensors can not be made non-contiguous
                if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                    return obj

                out = torch.repeat_interleave(obj, 2, dim=-1)
                out = out[..., ::2].detach()
                out.requires_grad = obj.requires_grad
                return out
            return self._traverse_obj(obj, inner_make_non_contiguous)

        def _can_be_noncontiguous(obj):
            if isinstance(obj, (tuple, list)):
                return any(_can_be_noncontiguous(o) for o in obj)
            elif isinstance(obj, dict):
                return any(_can_be_noncontiguous(o) for o in obj.values())
            # scalar tensors can not be non-contiguous
            return isinstance(obj, torch.Tensor) and obj.dim() != 0

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)):
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            self._retain_grad((input_args, input_kwargs))

            # === Forward with default input
            with freeze_rng_state():
                default_output = m(*input_args, **input_kwargs)
                if isinstance(default_output, torch.Tensor):
                    grad_output = default_output.clone().detach_().normal_()
                    default_output.backward(grad_output, retain_graph=True)
                else:
                    grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_() if o.requires_grad else None)
                                        for o in default_output)
                    flattened_default_output = torch.utils._pytree.tree_leaves(default_output)
                    flattened_grad_output = torch.utils._pytree.tree_leaves(grad_output)
                    for o, g_o in zip(flattened_default_output, flattened_grad_output):
                        if o.requires_grad:
                            o.backward(g_o, retain_graph=True)

            default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
            default_param_grad = deepcopy([p.grad for p in m.parameters()])

            # === Construct non-contiguous tensors ===
            nc_input_args, nc_input_kwargs = _make_non_contiguous((input_args, input_kwargs))
            nc_grad_output = _make_non_contiguous(grad_output)

            # === Compare results with non-contiguous and contiguous tensors ===
            inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)]
            grads = [grad_output, nc_grad_output]

            for (in_args, in_kwargs), g_out in product(inputs, grads):
                g_out_copy = deepcopy(g_out)
                self._zero_grad((in_args, in_kwargs))
                self._zero_grad(m.parameters())

                with freeze_rng_state():
                    out = m(*in_args, **in_kwargs)
                    if isinstance(out, torch.Tensor):
                        out.backward(g_out_copy, retain_graph=True)
                    else:
                        flattened_out = torch.utils._pytree.tree_leaves(out)
                        flattened_g_out_copy = torch.utils._pytree.tree_leaves(g_out_copy)
                        for o, g_o in zip(flattened_out, flattened_g_out_copy):
                            if o.requires_grad:
                                o.backward(g_o, retain_graph=True)

                input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
                self.assertEqual(out, default_output)
                self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0)
                self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0)

                param_grad = [p.grad for p in m.parameters()]
                self.assertEqual(param_grad, default_param_grad)

    def _test_gradients_helper(self, device, dtype, module_info, training, check):
        # Check gradients
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        # === Set nondet tol for gradcheck to user-defined value if on CUDA and cuDNN is enabled
        gradcheck_nondet_tol = 0.0
        if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):
            gradcheck_nondet_tol = module_info.gradcheck_nondet_tol

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            params = tuple(m.parameters())

            # === Lazy modules need to see an input to initialize params before gradcheck is run. ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                with torch.no_grad():
                    m(*input_args, **input_kwargs)

            # === Perform gradient check on the input_args ===
            other_kwargs = {}
            kwarg_tensors = []
            for name, obj in input_kwargs.items():
                if isinstance(obj, torch.Tensor):
                    kwarg_tensors.append((name, obj))
                else:
                    other_kwargs[name] = obj

            def fn_to_gradcheck(*flat_input_and_params):
                input_and_params = torch.utils._pytree.tree_unflatten(flat_input_and_params, flat_spec)
                new_input_args = input_and_params[:len(input_args)]
                kwarg_args = input_and_params[-len(kwarg_tensors):]
                new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}

                with freeze_rng_state():
                    output = m(*new_input_args, **new_kwargs, **other_kwargs)
                    output_flattened = torch.utils._pytree.tree_leaves(output)
                    return output_flattened

            # check total derivative
            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
            flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)

            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))

            # check partial derivatives
            old_params_requires_grad = [p.requires_grad for p in params]
            for p in params:
                p.requires_grad = False

            old_kwargs_requires_grad = [obj.requires_grad for (_, obj) in kwarg_tensors]
            for (_, obj) in kwarg_tensors:
                obj.requires_grad = False

            for p, old in zip(params, old_params_requires_grad):
                p.requires_grad = old
                grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
                p.requires_grad = False

            for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
                obj.requires_grad = old
                grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
                obj.requires_grad = False

    @modules(module_db, allowed_dtypes=[torch.double])
    def test_grad(self, device, dtype, module_info, training):
        self._test_gradients_helper(device, dtype, module_info, training, gradcheck)

    @modules([m for m in module_db if m.supports_gradgrad],
             allowed_dtypes=[torch.double])
    def test_gradgrad(self, device, dtype, module_info, training):
        self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)

    @onlyCUDA
    @with_tf32_off  # Turn off TF32 to compute at full precision https://github.com/pytorch/pytorch/issues/86798
    @toleranceOverride({torch.float32: tol(5e-2, 0),
                        torch.float64: tol(4e-4, 0)})
    @modules(module_db)
    def test_cpu_gpu_parity(self, device, dtype, module_info, training):
        # TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
        # nicer way for eval mode only.
        # See https://github.com/pytorch/pytorch/issues/79161
        rnn_modules = {torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM}
        if (module_info.module_cls in rnn_modules
                and not training
                and 'cuda' in device
                and torch.backends.cudnn.enabled):
            return

        # Test cpu and gpu results are the same
        module_cls = module_info.module_cls
        module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
                                                           requires_grad=True, training=training)

        def _to_device(obj):
            if isinstance(obj, torch.Tensor):
                res = obj.detach().to(device=device)
                res.requires_grad = obj.requires_grad
                return res
            elif isinstance(obj, tuple):
                return tuple(_to_device(o) for o in obj)
            elif isinstance(obj, dict):
                return {key: _to_device(o) for key, o in obj.items()}
            else:
                return deepcopy(obj)

        for module_input in module_inputs_cpu:
            # === Move input from cpu to device ===
            cpu_forward_args = module_input.forward_input.args
            cpu_forward_kwargs = module_input.forward_input.kwargs

            gpu_forward_args, gpu_forward_kwargs = _to_device((cpu_forward_args, cpu_forward_kwargs))

            self._retain_grad((cpu_forward_args, cpu_forward_kwargs, gpu_forward_args, gpu_forward_kwargs))

            # === Construct module on cpu and gpu ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
            cpu_module.train(training)
            gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
            gpu_module.train(training)

            # === Lazy modules need to see an input to initialize params ===
            if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                with torch.no_grad():
                    cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
                    gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

            for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                gpu_p.data.copy_(cpu_p)

            # === Compare forward output between cpu and gpu ===
            cpu_outputs = cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
            gpu_outputs = gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

            self.assertEqual(cpu_outputs, gpu_outputs)

            # === Run backwards on CPU and GPU and compare results ===
            def check_backward(cpu_output, gpu_output):
                cpu_grad_output = cpu_output.clone().normal_()
                gpu_grad_output = cpu_grad_output.type_as(gpu_output)

                cpu_output.backward(cpu_grad_output, retain_graph=True)
                gpu_output.backward(gpu_grad_output, retain_graph=True)

                cpu_grad_input = self._get_grads(cpu_forward_args)
                gpu_grad_input = self._get_grads(gpu_forward_args)
                self.assertEqual(cpu_grad_input, gpu_grad_input)

                for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                    self.assertEqual(cpu_p.grad, gpu_p.grad)

                cpu_grad_kwarg_input = self._get_grads(cpu_forward_kwargs)
                gpu_grad_kwarg_input = self._get_grads(gpu_forward_kwargs)
                self.assertEqual(cpu_grad_kwarg_input, gpu_grad_kwarg_input)

            for _ in range(5):
                if isinstance(cpu_outputs, torch.Tensor):
                    check_backward(cpu_outputs, gpu_outputs)
                else:
                    flatten_cpu_outputs = torch.utils._pytree.tree_leaves(cpu_outputs)
                    flatten_gpu_outputs = torch.utils._pytree.tree_leaves(gpu_outputs)
                    for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
                        if cpu_output.requires_grad:
                            check_backward(cpu_output, gpu_output)

    @with_tf32_off
    @modules(module_db)
    def test_memory_format(self, device, dtype, module_info, training):
        is_sm86or80 = device.startswith("cuda") and (torch.cuda.get_device_capability(0) == (8, 6)
                                                     or torch.cuda.get_device_capability(0) == (8, 0))
        # TODO tighten it to a specific module
        atol, rtol = (3e-3, 7e-3) if is_sm86or80 else (None, None)
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        module_memformat_affects_out = module_info.module_memformat_affects_out

        def _get_mem_formats(channels_last=False, channels_last_3d=False):
            if channels_last:
                return ([torch.contiguous_format, torch.channels_last],
                        [torch.preserve_format, torch.contiguous_format, torch.channels_last])
            elif channels_last_3d:
                return ([torch.contiguous_format, torch.channels_last_3d],
                        [torch.preserve_format, torch.contiguous_format, torch.channels_last_3d])
            else:
                return ([torch.contiguous_format],
                        [torch.preserve_format, torch.contiguous_format])

        # Check that at least one Tensor input has dim == n
        def _check_dims(obj, n):
            if isinstance(obj, torch.Tensor):
                return obj.dim() == n
            elif isinstance(obj, (tuple, list)):
                return any(_check_dims(o, n) for o in obj)
            else:
                return False

        # Called after _check_dims, when we know that >= 1 tensor can be converted to mem_format
        def _to_mem_format(mem_format, obj):
            def inner_to_mem_format(obj):
                d = obj.dim()
                if ((mem_format == torch.channels_last and d != 4)
                        or (mem_format == torch.channels_last_3d and d != 5)):
                    return obj.clone().detach().requires_grad_(obj.requires_grad)
                return obj.clone().to(memory_format=mem_format).detach().requires_grad_(obj.requires_grad)

            return self._traverse_obj(obj, inner_to_mem_format)

        def _check_out_mem_format(output, input_mem_format, module_mem_format):
            def inner_check_out_mem_format(output):
                d = output.dim()
                if (d == 4 and ((input_mem_format == torch.channels_last)
                                or (module_mem_format == torch.channels_last and module_memformat_affects_out))):
                    self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last))
                elif (d == 5 and ((input_mem_format == torch.channels_last_3d)
                                  or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))):
                    self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last_3d))
                else:
                    self.assertTrue(output.is_contiguous())
            return self._traverse_obj(output, inner_check_out_mem_format)

        def _req_grad(t):
            return isinstance(t, torch.Tensor) and t.requires_grad

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            supports_channels_last = _check_dims(module_input.forward_input.args, 4)
            supports_channels_last_3d = _check_dims(module_input.forward_input.args, 5)
            input_mem_formats, module_mem_formats = _get_mem_formats(supports_channels_last, supports_channels_last_3d)

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Get output in (contiguous, contiguous) configuration. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                desired_outputs = m(*args, **kwargs)
                # === Do backward pass. ===
                ref_diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(desired_outputs) if _req_grad(t))
                if training and len(ref_diff_outputs) > 0:
                    params = tuple(p for p in m.parameters())
                    ref_diff_inputs = tuple(
                        t
                        for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                        if _req_grad(t)
                    )
                    ref_grad_outputs = tuple(
                        torch.rand_like(t)
                        for t in ref_diff_outputs
                    )
                    ref_grad_inputs = torch.autograd.grad(
                        ref_diff_outputs,
                        ref_diff_inputs,
                        grad_outputs=ref_grad_outputs,
                    )

                for input_mem_format in input_mem_formats:
                    # === Change memformat of input. ===
                    d_args = _to_mem_format(input_mem_format, module_input.forward_input.args)
                    d_kwargs = _to_mem_format(input_mem_format, module_input.forward_input.kwargs)

                    # See https://github.com/pytorch/pytorch/issues/107861
                    # When inductor tests are turned on, the setting of requires_grad will be lost
                    for t1, t2 in zip(
                        torch.utils._pytree.tree_leaves(d_args),
                        torch.utils._pytree.tree_leaves(module_input.forward_input.args),
                    ):
                        t1.requires_grad_(t2.requires_grad)
                    for t1, t2 in zip(
                        torch.utils._pytree.tree_leaves(d_kwargs),
                        torch.utils._pytree.tree_leaves(module_input.forward_input.kwargs),
                    ):
                        t1.requires_grad_(t2.requires_grad)

                    module_input.forward_input.args = d_args
                    module_input.forward_input.kwargs = d_kwargs

                    for module_mem_format in module_mem_formats:
                        # === Change memformat of module ===
                        m.to(memory_format=module_mem_format)

                        # === Do forward pass. ===
                        args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                        outputs = m(*args, **kwargs)

                        # === Compare outputs to (contiguous, contiguous) output. ===
                        if input_mem_format != torch.contiguous_format or module_mem_format != torch.contiguous_format:
                            self.assertEqual(outputs, desired_outputs, rtol=rtol, atol=atol)

                        # === Check mem format of output. ===
                        _check_out_mem_format(outputs, input_mem_format, module_mem_format)

                        # === Do backward pass. ===
                        diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(outputs) if _req_grad(t))
                        if training and len(diff_outputs) > 0:
                            params = tuple(p for p in m.parameters())
                            diff_inputs = tuple(
                                t
                                for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                                if _req_grad(t)
                            )
                            grad_outputs = tuple(
                                torch.empty_like(t1).copy_(t2)
                                for (t1, t2) in zip(diff_outputs, ref_grad_outputs)
                            )

                            grad_inputs = torch.autograd.grad(
                                diff_outputs,
                                diff_inputs,
                                grad_outputs=grad_outputs,
                            )

                            if (
                                input_mem_format != torch.contiguous_format
                                or module_mem_format != torch.contiguous_format
                            ):
                                self.assertEqual(
                                    grad_inputs, ref_grad_inputs, rtol=rtol, atol=atol
                                )

                            # === Check mem format of grad_inputs. ===
                            _check_out_mem_format(grad_inputs, input_mem_format, module_mem_format)

    # Test whether train and eval modes differ for each module. Use to verify
    # that the ModuleInfo entry flag is correct.
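    # (For example, torch.nn.Dropout and torch.nn.BatchNorm2d behave differently in train vs. eval
    # mode, while modules such as torch.nn.ReLU do not; illustrative examples only.)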
    @modules(module_db, train_eval_mode=TrainEvalMode.train_only)
    def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)

        # Run forward inputs through to see if the training flag is accessed during forward.
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            # Remove training attribute and see if forward still works.
            delattr(m, 'training')

            # === Do forward pass. ===
            try:
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                m(*args, **kwargs)
            except AttributeError as e:
                if "'training'" in str(e):
                    self.assertTrue(module_info.train_and_eval_differ,
                                    f"The ModuleInfo entry for {module_info.name} has "
                                    "train_and_eval_differ=False, but the training mode was found to "
                                    "affect the forward pass. Consider setting train_and_eval_differ=True "
                                    "for this ModuleInfo entry.")
                else:
                    raise e

    @onlyCPU
    @modules(module_db)
    def test_device_ctx_init(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        with torch.device('meta'):
            module_inputs_meta = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
                                                                requires_grad=False, training=training)

        for module_input, module_input_meta in zip(module_inputs, module_inputs_meta):
            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            c_args_meta, c_kwargs_meta = module_input_meta.constructor_input.args, module_input_meta.constructor_input.kwargs

            m_cpu = module_cls(*c_args, **c_kwargs)

            with torch.device('meta'):
                m = module_cls(*c_args_meta, **c_kwargs_meta)

            for (p_meta, p_cpu) in chain(zip(m.parameters(), m_cpu.parameters()),
                                         zip(m.buffers(), m_cpu.buffers())):
                if torch.nn.parameter.is_lazy(p_meta):
                    continue
                self.assertTrue(p_meta.is_meta)
                assert_metadata_eq(self.assertEqual, p_meta, p_cpu)

    @modules([module for module in module_db if module.module_error_inputs_func is not None])
    def test_errors(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        error_inputs = module_info.module_error_inputs_func(module_info, device=device, dtype=dtype,
                                                            requires_grad=False, training=training)
        for error_input in error_inputs:
            module_input = error_input.module_error_input
            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            if error_input.error_on == ModuleErrorEnum.CONSTRUCTION_ERROR:
                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                    m = module_cls(*c_args, **c_kwargs)
            elif error_input.error_on == ModuleErrorEnum.FORWARD_ERROR:
                m = module_cls(*c_args, **c_kwargs)
                fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                    m(*fw_args, **fw_kwargs)
            else:
                raise NotImplementedError(f"Unknown error type {error_input.error_on}")

    # Only run this test for float32 because the test loops over all the dtypes
    @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
    @parametrize('swap', [True, False])
    @parametrize('set_grad', [True, False])
    @wrapSwapTensorsTest()
    def test_to(self, device, dtype, module_info, training, swap, set_grad):
        module_cls = module_info.module_cls
        devices = ['cpu']
        if torch.cuda.is_available():
            devices += ['cuda']
        dtypes = module_info.dtypes
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        torch.__future__.set_swap_module_params_on_conversion(swap)

        for module_input in module_inputs:
            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

            m = module_cls(*c_args, **c_kwargs)

            # Avoid using `module.to()` when constructing module since that is the method we are testing
            def _to(m, set_grad=False):
                for c in m.children():
                    _to(c, set_grad=set_grad)
                for n, p in m.named_parameters(recurse=False):
                    new_p = torch.nn.Parameter(p.detach().clone().to(device, dtype))
                    setattr(m, n, new_p)
                    if set_grad:
                        new_p.grad = torch.randn_like(new_p)
                for n, b in m.named_buffers(recurse=False):
                    new_b = b.detach().clone().to(device, dtype)
                    setattr(m, n, new_b)
            _to(m, set_grad=set_grad)

            # Check .to() can be run after forward and backward with swap
            has_params = len(list(m.parameters())) > 0
            if swap and not set_grad and has_params:
                out = m(*args, **kwargs)
                if isinstance(out, tuple):
                    out = out[0]
                out.sum().backward()
                m.to(dtype=torch.half)
                # reset
                m.to(dtype=torch.float32)

            prev_device, prev_dtype = device, dtype
            for device_, dtype_ in product(devices, dtypes):
                # if device/dtype do not change, grad.to(device, dtype) is a no-op so
                # swapping will not change ._cdata
                # parameters will be wrapped in an nn.Parameter before swapping
                # which will cause the ._cdata to change
                g_no_swap = device_ == prev_device and dtype_ == prev_dtype
                prev_prev_device, prev_prev_dtype = prev_device, prev_dtype
                prev_device, prev_dtype = device_, dtype_

                p_ids_before = [id(p) for p in m.parameters()]
                p_cdatas_before = [p._cdata for p in m.parameters()]
                if set_grad:
                    g_ids_before = [id(p.grad) for p in m.parameters()]
                    g_cdatas_before = [p.grad._cdata for p in m.parameters()]

                m.to(device=device_, dtype=dtype_)

                self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
                self.assertTrue(all(p.device.type == device_ for p in m.parameters()))
                self.assertTrue(all(p.dtype == dtype_ for p in m.parameters()))
                p_ids_after = [id(p) for p in m.parameters()]
                p_cdatas_after = [p._cdata for p in m.parameters()]

                if set_grad:
                    self.assertTrue(all(p.grad.device.type == device_ for p in m.parameters()))
                    self.assertTrue(all(p.grad.dtype == dtype_ for p in m.parameters()))
                    g_ids_after = [id(p.grad) for p in m.parameters()]
                    g_cdatas_after = [p.grad._cdata for p in m.parameters()]

                if swap:
                    # id same, ._cdata differs --> swapped cdata of THPVariable
                    self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
                    self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
                    if set_grad:
                        self.assertTrue(
                            all(a == b if g_no_swap else a != b for a, b in zip(g_cdatas_before, g_cdatas_after)))
                else:
                    # id and _cdata remain the same --> .data setting
                    self.assertTrue(all(a == b for a, b in zip(p_cdatas_before, p_cdatas_after)))
                    self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
                    if set_grad:
                        self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after)))
                        self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))

    @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
    @parametrize('swap', [True, False])
    @wrapSwapTensorsTest()
    def test_to_empty(self, device, dtype, module_info, swap, training):
        module_cls = module_info.module_cls

        with torch.device("meta"):
            module_inputs = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
                                                           requires_grad=False, training=training)

        torch.__future__.set_swap_module_params_on_conversion(swap)
        device_ = torch.device(device)

        for module_input in module_inputs:
            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            with torch.device("meta"):
                m = module_cls(*c_args, **c_kwargs)

            p_ids_before = [id(p) for p in m.parameters()]
            p_cdatas_before = [p._cdata for p in m.parameters()]
            m.to_empty(device=device_)

            self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
            self.assertTrue(all(p.device == device_ for p in m.parameters()))
            self.assertTrue(all(p.dtype == dtype for p in m.parameters()))
            p_ids_after = [id(p) for p in m.parameters()]
            p_cdatas_after = [p._cdata for p in m.parameters()]

            if swap:
                # id same, ._cdata differs --> swapped cdata of THPVariable
                self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
            else:
                # id and ._cdata differ
                # meta and device have different shallow copy types, so this will create a new
                # parameter and assign it to the module
                self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))


instantiate_device_type_tests(TestModule, globals(), allow_mps=True)

if __name__ == '__main__':
    run_tests()