# Owner(s): ["module: nn"]
import unittest
from dataclasses import dataclass
from functools import partial
from itertools import chain, product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.utils._expanded_weights import ExpandedWeight
from torch.nn.utils._expanded_weights.expanded_weights_utils import (
    forward_helper,
    set_grad_sample_if_exists,
    standard_kwargs,
    sum_over_all_but_batch_and_last_n,
    unpack_expanded_weight_or_tensor,
)
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
from torch.testing._internal.common_utils import (
    freeze_rng_state,
    make_tensor,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._pytree import tree_map_only


class TestContext:
    pass


class TestExpandedWeightHelperFunction(TestCase):
    def test_forward_helper(self, device):
        input = torch.randn(3, 4, device=device)
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        for weight_batched, bias_batched in product([True, False], [True, False]):
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 3, loss_reduction="sum"
                )
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 3, loss_reduction="sum"
                )
            args = (input, maybe_batched_weight, maybe_batched_bias)
            expanded_args, expanded_kwargs = standard_kwargs(("bias",), args)
            res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
            expected = nn.functional.linear(input, weight, bias)
            self.assertEqual(res, expected)

            self.assertEqual(len(expanded_args), 2)
            assert expanded_args[0] is args[0]  # avoids property checks in assertEquals
            assert expanded_args[1] is args[1]  # avoids property checks in assertEquals
            self.assertEqual(len(expanded_kwargs), 1)
            assert (
                expanded_kwargs["bias"] is args[2]
            )  # avoids property checks in assertEquals

    def test_forward_helper_failure_args(self, device):
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        with self.assertRaisesRegex(
            RuntimeError, r"do not support inputs that are also ExpandedWeights."
        ):
            input = ExpandedWeight(
                torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum"
            )
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (input, weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"requires a Tensor as the first input"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (3, weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"requires a batch dimension but got an input of size 0"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (torch.tensor(3), weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"0 is not a valid batch size for Expanded Weights"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (torch.randn(0, 1, 2), weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        input = torch.randn(3, 4)
        for weight_batched, bias_batched in product([True, False], [True, False]):
            if not weight_batched and not bias_batched:
                continue
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 4, loss_reduction="sum"
                )
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 4, loss_reduction="sum"
                )
            with self.assertRaisesRegex(
                RuntimeError,
                r"Expected ExpandedWeights to have batch size matching input",
            ):
                expanded_args, expanded_kwargs = standard_kwargs(
                    ("bias",), (input, maybe_batched_weight, maybe_batched_bias)
                )
                forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

    def test_set_grad_sample_if_exists(self, device):
        def test_fn(a):
            return grad_sample

        orig_weight = torch.randn(4, device=device, requires_grad=True)
        expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
        grad_sample = torch.randn(3)
        set_grad_sample_if_exists(expanded_weight, test_fn)
        self.assertTrue(hasattr(orig_weight, "grad_sample"))
        self.assertEqual(orig_weight.grad_sample, grad_sample)

        basic_tensor = torch.randn(4, device=device)
        set_grad_sample_if_exists(basic_tensor, test_fn)
        self.assertFalse(hasattr(basic_tensor, "grad_sample"))

        non_tensor = 3
        set_grad_sample_if_exists(non_tensor, test_fn)
        self.assertFalse(hasattr(non_tensor, "grad_sample"))

    def test_set_grad_sample_if_exists_failure(self, device):
        def test_fn(a):
            return True

        grad_tensor = torch.randn(4, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            set_grad_sample_if_exists(grad_tensor, test_fn)

    def test_unpack_expanded_weight_or_tensor(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertEqual(
            input,
            unpack_expanded_weight_or_tensor(
                ExpandedWeight(input, 3, loss_reduction="sum")
            ),
        )

        input.requires_grad_(False)
        self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
        self.assertTrue(unpack_expanded_weight_or_tensor(4) is None)

    def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertTrue(
            unpack_expanded_weight_or_tensor(
                ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input
            )
        )

        input.requires_grad_(False)
        self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
        self.assertTrue(
            unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None
        )

    def test_unpack_expanded_weight_or_tensor_failure(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            unpack_expanded_weight_or_tensor(input)

        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            unpack_expanded_weight_or_tensor(input, lambda x: x is input)

    def test_sum_over_all_but_batch_and_last_n(self, device):
        input = torch.randn(1, 2, 3, 4, 5, device=device)
        res = sum_over_all_but_batch_and_last_n(input, 2)
        expected = input.sum((1, 2))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 0)
        expected = input.sum((1, 2, 3, 4))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 4)
        self.assertEqual(res, input)
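# The tests below compare per-sample gradients produced by the ExpandedWeight
# machinery against a reference computed with an explicit per-sample for loop.
# A rough sketch of the pattern exercised by the module-level tests further down:
#
#     out = call_for_per_sample_grads(module, loss_reduction="sum")(batched_input)
#     out.sum().backward()
#     # parameters now carry .grad_sample, whose leading dimension is the batch size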
class TestExpandedWeightFunctional(TestCase):
    def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
        input = sample_input.input
        args = sample_input.args
        kwargs = sample_input.kwargs
        batch_size = input.shape[0] if len(input.shape) > 1 else 1

        # get per sample grads with ExpandedWeights objects
        loss_reduction = "sum" if reduction == torch.sum else "mean"
        (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
            sample_input, batch_size, loss_reduction
        )
        diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
        diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
        diff_input_list = [
            i.orig_weight if isinstance(i, ExpandedWeight) else i
            for i in diff_input_list
        ]
        if not diff_input_list:
            return
        result = run_op(op, ew_input, *ew_args, **ew_kwargs)
        reduction(
            result
        ).backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__
        expanded_weight_grad = tuple(
            i.grad_sample if hasattr(i, "grad_sample") else i.grad
            for i in diff_input_list
        )

        # get per sample grads with for loop
        func = partial(run_op, op)

        per_sample_grad = for_loop_per_sample_grad(
            batch_size, reduction, input, func, *args, **kwargs
        )

        # check equality
        self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
        if loss_reduction == "mean":
            # don't check equality of `input.grad`s since these vanilla tensors won't be scaled
            expanded_weight_grad = expanded_weight_grad[1:]
            per_sample_grad = per_sample_grad[1:]
        for result_grad, expected_grad in zip(expanded_weight_grad, per_sample_grad):
            self.assertEqual(result_grad, expected_grad)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )
            sample_input.input.requires_grad_(False)

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @skipIfTorchDynamo("Checking error message doesn't work with dynamo")
    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_unsupported_expand_weights(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False)
        for sample_input in unsupported_inputs:
            with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"):
                if (
                    op.name == "nn.functional.embedding"
                ):  # embedding flips its argument order for autograd tests
                    sample_input = SampleInput(
                        sample_input.args[0],
                        args=(sample_input.input,),
                        kwargs=sample_input.kwargs,
                    )
                input = sample_input.input

                batch_size = input.shape[0] if len(input.shape) > 1 else 1

                # get per sample grads with ExpandedWeights objects
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size
                )
                result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                diff_input_list = (
                    (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
                )
                diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
                diff_input_list = [
                    i.orig_weight if isinstance(i, ExpandedWeight) else i
                    for i in diff_input_list
                ]
                result.sum().backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported
    )
    def test_expanded_weight_forward(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0].clone(),
                    args=(sample_input.input.clone(),),
                    kwargs=sample_input.kwargs,
                )
            if (
                "cuda" in device
                and "max_norm" in sample_input.kwargs
                and "padding_idx" in sample_input.kwargs
            ):
                self.skipTest(
                    "embedding is non-deterministic in this case, see issue #74679"
                )
            batch_size = (
                sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
            )
            for loss_reduction in ["sum", "mean"]:
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size, loss_reduction
                )
                expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                normal_result = run_op(
                    op, sample_input.input, *sample_input.args, **sample_input.kwargs
                )
                self.assertEqual(expanded_weight_result, normal_result)

    def test_expanded_weight_error(self, device):
        batch_size = 3
        sample_input = make_tensor(
            (batch_size, 4), dtype=torch.float32, device=device, requires_grad=True
        )
        sample_weight = make_tensor(
            (4), dtype=torch.float32, device=device, requires_grad=True
        )
        with self.assertRaisesRegex(
            RuntimeError, r"Expanded Weights encountered but cannot handle function"
        ):
            torch.add(
                sample_input,
                ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"),
            )
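    # The helpers below build small models and check that grad_sample (populated by
    # call_for_per_sample_grads followed by backward()) matches gradients computed by
    # running the same model on one sample at a time with torch.autograd.grad.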
    def _test_embedding_model(self, model, num_embedding, device):
        batch_size = 32
        input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
        return self._test_model(
            partial(model, num_embedding=num_embedding), batch_size, input, device
        )

    def _test_conv_model(
        self,
        model,
        input_size,
        num_dim,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        batch_size = 32
        input_ending = [input_size] * num_dim
        input = torch.randn([batch_size, 3] + input_ending, device=device)
        return self._test_model(
            partial(model, num_dim=num_dim),
            batch_size,
            input,
            device,
            loss_reduction,
            atol,
            rtol,
        )

    def _test_model(
        self,
        model,
        batch_size,
        input,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        model = model(10).to(device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        criterion = CrossEntropyLoss(reduction=loss_reduction)
        result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input)
        loss = criterion(result, targets)
        loss.backward()
        result = []
        for weight in model.parameters():
            result.append(weight.grad_sample)
            del weight.grad_sample

        expected = []
        for i in range(batch_size):
            loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0))
            expected.append(
                torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss))
            )

        expected = [torch.stack(grad) for grad in zip(*expected)]
        for res, exp in zip(result, expected):
            self.assertEqual(res, exp, atol=atol, rtol=rtol)

    def _compute_tolerances(self, device):
        is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(
            0
        ) == (8, 6)
        return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5)

    @tf32_off()
    def test_cnn_model_sum(self, device):
        def convnet(num_classes, num_dim):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol)

    @tf32_off()
    def test_cnn_model_mean(self, device):
        def convnet(num_classes, num_dim):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_instance_norm_model(self, num_dim, device):
        def instance_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            norm_layer = (
                nn.InstanceNorm1d
                if num_dim == 1
                else nn.InstanceNorm2d
                if num_dim == 2
                else nn.InstanceNorm3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                norm_layer(32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_group_norm_model(self, num_dim, device):
        def group_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(8, 32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_layer_norm_model(self, num_dim, device):
        def layer_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            normalized_shape = [7] * num_dim
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.LayerNorm(normalized_shape, elementwise_affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    def test_embedding_model(self, device):
        def embedding_model(num_classes, num_embedding):
            return nn.Sequential(
                nn.Embedding(num_embedding, 15),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(375, num_classes, bias=True),
            )

        return self._test_embedding_model(embedding_model, 16, device)

    def test_group_norm_error(self, device):
        # group norm has to call native_group_norm. This checks that it hits the same errors
        # that normal group norm would

        N = 3
        C = 5
        inp = torch.randn(N, C)
        with self.assertRaisesRegex(
            RuntimeError, r"Expected number of channels in input to be divisible"
        ):
            F.group_norm(inp, 2)  # 5 is not divisible by 2
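# Module-level checks: run a module through call_for_per_sample_grads, then compare
# each parameter's grad_sample against gradients from a per-sample for loop. RNN
# modules get extra handling for unbatched, batch-second, and packed-sequence inputs.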
class TestExpandedWeightModule(TestCase):
    def _do_test(
        self,
        module,
        input,
        args=None,
        kwargs=None,
        batch_first=True,
        atol=None,
        rtol=None,
    ):
        args = args or ()
        kwargs = kwargs or {}

        batch_dim = 0 if batch_first else 1
        batch_size = input.shape[batch_dim]
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()

        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(
                module,
                batch_size=batch_size,
                loss_reduction="sum",
                batch_first=batch_first,
            )(input, *args, **kwargs).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop
            expected_res = torch.tensor(
                0.0, device=input.device, dtype=actual_res.dtype
            )
            expected_grads = []
            for i in range(batch_size):
                input_slice = input.narrow(batch_dim, i, 1)
                input_slice = input_slice.squeeze(batch_dim)

                # h's batch dim is always the first dim. Must be contiguous for CUDA
                sliced_args = tree_map_only(
                    torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args
                )
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(
                    input_slice.unsqueeze(batch_dim).contiguous(),
                    *sliced_args,
                    **kwargs,
                ).sum()
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)
                expected_res += res
            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
            if not batch_first:
                expected_grads[-1] = expected_grads[-1].transpose(0, 1)
        self.assertEqual(actual_res, expected_res)
        for actual, expected in zip(actual_grads, expected_grads):
            self.assertEqual(actual, expected, atol=atol, rtol=rtol)

    def _do_test_multi_input(self, module, input):
        class TestModule(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, input):
                return self.module(input) + self.module(input)

        batch_size = input.shape[0]
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()
        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager, calling .backward() twice
            test_module = TestModule(module)
            actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")(
                input
            ).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop, running over the input twice
            expected_grads = []
            for i in range(batch_size):
                input_slice = input[i]
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(input_slice.unsqueeze(0)).sum()
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)
            expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
            expected_grads = tuple(
                expected_grad
                for expected_grad in expected_grads
                if expected_grad is not None
            )
        for actual, expected in zip(actual_grads, expected_grads):
            self.assertEqual(actual, 2 * expected)
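    # For PackedSequence inputs the per-sample reference cannot simply index the batch
    # dim: the batch size is taken as the largest entry of input.batch_sizes, the input
    # is unpacked with pad_packed_sequence, and each sequence is replayed at its own length.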
    def _do_test_rnn_packed_sequence(
        self, module, input, args=None, kwargs=None, atol=None, rtol=None
    ):
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        batch_size = max(tuple(input.batch_sizes)).item()

        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(
                module, batch_size=batch_size, loss_reduction="sum"
            )(input, *args, **kwargs).data.sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                self.assertEqual(param.grad_sample.shape[0], batch_size)
                actual_grads.append(param.grad_sample)
                del param.grad_sample

            input.data.grad = torch.zeros_like(input.data)

            # compute the per sample grads with a for loop
            expected_res = torch.zeros_like(actual_res)
            expected_grads = []
            padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence(
                input, batch_first=True
            )
            for i in range(len(seq_sizes)):
                input_slice = padded_input[i].narrow(0, 0, seq_sizes[i])
                diff_params = module.parameters()
                batch_dim = 0 if module.m.batch_first else 1
                res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum()
                expected_res += res
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)

            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
            self.assertEqual(actual_res, expected_res)
            for actual, expected in zip(actual_grads, expected_grads):
                self.assertEqual(actual, expected, atol=atol, rtol=rtol)

    @modules(
        filter(
            lambda m_info: m_info.module_cls
            in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
            module_db,
        )
    )
    @tf32_off()
    def test_module(self, device, dtype, module_info, training):
        class RNNWrapper(torch.nn.Module):
            def __init__(self, m_cons, args, kwargs):
                super().__init__()
                self.m = m_cons(*args, **kwargs)

            def forward(self, *inps):
                ret = self.m(*inps)
                assert isinstance(ret, tuple)
                return ret[0]

        def batch_hidden(h):
            new_h_shape = [1] * (len(h.shape) + 1)
            new_h_shape[1] = 2
            return h.unsqueeze(1).repeat(new_h_shape)

        module_cls = module_info.module_cls
        atol, rtol = (
            (1e-4, 1e-5)
            if module_cls == torch.nn.GRU and dtype == torch.float32
            else (None, None)
        )
        module_inputs = module_info.module_inputs_func(
            module_info,
            device=device,
            dtype=dtype,
            requires_grad=True,
            training=training,
            with_packed_sequence=True,
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )
            m = RNNWrapper(module_cls, args, kwargs)
            batch_first = m.m.batch_first
            m.to(device).to(dtype)

            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )

            # if the RNN tests use unbatched inputs, batch the inputs
            input = args[0]
            if isinstance(input, torch.Tensor) and input.dim() == 2:
                input = input.detach()
                new_input_shape = [1] * (len(input.shape) + 1)
                if batch_first:
                    new_input_shape[0] = 2
                    input = input.repeat(new_input_shape)
                else:
                    new_input_shape[1] = 2
                    input = input.unsqueeze(1).repeat(new_input_shape)

            h = args[1] if len(args) > 1 else None
            if h is not None:
                h = (
                    batch_hidden(h)
                    if isinstance(h, torch.Tensor)
                    else tuple(batch_hidden(hx) for hx in h)
                )
                args = list(args)
                args[1] = h

            if isinstance(input, torch.nn.utils.rnn.PackedSequence):
                self._do_test_rnn_packed_sequence(
                    m, input, args[1:], kwargs, atol=atol, rtol=rtol
                )
            else:
                self._do_test(
                    m,
                    input,
                    args[1:],
                    kwargs,
                    batch_first=batch_first,
                    atol=atol,
                    rtol=rtol,
                )

    def test_per_sample_api_failing(self):
        module = nn.Linear(10, 10)
        input = torch.randn(64, 10)
        with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
            call_for_per_sample_grads("fail")(input)
        with self.assertRaisesRegex(
            RuntimeError, r"Batch size passed must be None or an integer"
        ):
            call_for_per_sample_grads(module, batch_size=6.4)(input)
        with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
            call_for_per_sample_grads(module, batch_size=-64)(input)
        with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
            loss = call_for_per_sample_grads(module)(input).sum()
            loss.backward()  # populate grad_sample fields
            call_for_per_sample_grads(module)(input)

        module = nn.Linear(10, 10)  # reset to not have grad_sample fields
        with self.assertRaisesRegex(
            RuntimeError, r"Expected loss_reduction argument to be sum or mean"
        ):
            call_for_per_sample_grads(module, loss_reduction="")(input)

    def test_per_sample_api_compute_batch_size(self):
        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1) + self.linear(input2)

        module = CustomModule()
        input1 = torch.randn(4, 5)
        input2 = torch.randn(5, 5)

        with self.assertRaisesRegex(
            RuntimeError,
            "found at least one input with batch size 4 and one with batch size 5",
        ):
            call_for_per_sample_grads(module)(input1, input2)

        input2 = torch.randn(4, 5)
        call_for_per_sample_grads(module)(input1, input2)

        module = CustomModule()
        call_for_per_sample_grads(module)(input1, input2=input2)

        module = CustomModule()
        call_for_per_sample_grads(module)(input1=input1, input2=input2)

    def test_per_sample_api_compute_batch_size_not_pytreeable(self):
        @dataclass
        class NonPytreeableTuple:
            elem1: torch.Tensor
            elem2: torch.Tensor

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1.elem1) + self.linear(input1.elem2)

        input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5))
        model = CustomModule()
        with self.assertRaisesRegex(
            RuntimeError,
            "ExpandedWeights cannot compute the batch size from the inputs",
        ):
            call_for_per_sample_grads(model)(input, "")

        # would prefer for it to error because input is not pytree-able but that's hard to detect
        with self.assertRaisesRegex(
            RuntimeError, "Expected ExpandedWeights to have batch size matching input"
        ):
            call_for_per_sample_grads(model)(input, torch.randn(5))

        model = CustomModule()  # TODO: functional call bug, sam will fix
        call_for_per_sample_grads(model)(input, torch.randn(4, 5))
        model = CustomModule()
        call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5))
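# ContextManagerTests adapts the legacy common_nn test descriptions (module_tests +
# new_module_tests) into per-sample-gradient checks; the loop below generates one
# single-input and one multi-input test method per supported module.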
class ContextManagerTests(TestBase):
    def __init__(self, *args, **kwargs):
        self.test_cpu = kwargs.get("test_cpu", True)
        self.test_cuda = kwargs.get("test_cuda", True)
        super().__init__(*args, **kwargs)

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", False)

    def test_context_manager(self, test_case, device):
        kwargs = {"device": device, "dtype": torch.double}
        module = self.constructor(*self.constructor_args).to(**kwargs)
        if "Embedding" in self.get_name():
            kwargs["dtype"] = torch.long
        input = self._get_input().to(**kwargs)
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest(
                "Can't get per sample gradients when no batch dim or batch dim is 0"
            )
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest(
                "Can't get per sample gradients for input of rank 1"
            )
        test_case._do_test(module, input)

    def test_context_manager_multiple_inputs(self, test_case, device):
        module = self.constructor(*self.constructor_args).to(device)
        input = self._get_input()
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest(
                "Can't get per sample gradients when no batch dim or batch dim is 0"
            )
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest(
                "Can't get per sample gradients for input of rank 1"
            )
        test_case._do_test_multi_input(module, input)


def filter_supported_tests(t):
    supported_modules = [
        "Linear",
        "Conv1d",
        "Conv2d",
        "Conv3d",
        "Embedding",
        "LayerNorm",
        "GroupNorm",
        "InstanceNorm",
    ]
    if "module_name" in t and t["module_name"] in supported_modules:
        return True
    return False


# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
# These currently use the legacy nn tests
supported_tests = [
    t for t in module_tests + new_module_tests if filter_supported_tests(t)
]
for test_param in supported_tests:
    if "constructor" not in test_param:
        name = test_param.pop("module_name")
        test_param["constructor"] = getattr(nn, name)
    decorator = test_param.pop("decorator", lambda test: test)
    test = ContextManagerTests(**test_param)
    test_name = test.get_name()
    if hasattr(TestExpandedWeightModule, test_name):
        raise RuntimeError("Found two tests with the same name: " + test_name)
    test_name_multi_input = test.get_name() + "_multiple_inputs"
    if hasattr(TestExpandedWeightModule, test_name_multi_input):
        raise RuntimeError("Found two tests with the same name: " + test_name)
    if test.test_cpu:
        setattr(
            TestExpandedWeightModule,
            test_name,
            decorator(lambda self, test=test: test.test_context_manager(self, "cpu")),
        )
        setattr(
            TestExpandedWeightModule,
            test_name_multi_input,
            decorator(
                lambda self, test=test: test.test_context_manager_multiple_inputs(
                    self, "cpu"
                )
            ),
        )
    if TEST_CUDA and test.test_cuda:
        # since this checks derivatives, only use double for precision
        setattr(
            TestExpandedWeightModule,
            test_name + "_cuda_double",
            decorator(lambda self, test=test: test.test_context_manager(self, "cuda")),
        )

# ------------- HELPER FUNCTIONS -----------------
def run_op(op, input, *args, **kwargs):
    r"""
    OpInfo for Embedding switches the input and weight so autograd tests will only check the derivative
    of the weight, not the input, which can't be differentiable since its dtype is int. Calls op,
    using the special ordering that Embedding's OpInfo expects for that case.
    """
    if op.name == "nn.functional.embedding":
        return op(args[0], input, **kwargs)
    else:
        return op(input, *args, **kwargs)


def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
    def expanded_weight_or_clone(arg):
        if is_diff_tensor(arg):
            return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
        return clone_if_tensor(arg)

    ew_input = clone_if_tensor(sample_input.input)
    ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
    ew_kwargs = {
        name: expanded_weight_or_clone(arg)
        for (name, arg) in sample_input.kwargs.items()
    }
    return ew_input, ew_args, ew_kwargs


def supported_inputs(op, sample_inputs, supported_inputs=True):
    r"""
    ExpandedWeights currently does not support some use cases when there's no batch dimension or
    operations that would cause inter-batch operations. Removes all of the cases it cannot deal with
    """

    def filter_fn(input):
        convolutions = [
            "nn.functional.conv1d",
            "nn.functional.conv2d",
            "nn.functional.conv3d",
        ]
        batched_input_size = dict(zip(convolutions, [3, 4, 5]))
        if op.name == "nn.functional.linear":
            is_supported_input = (
                input.input.dim() > 1
            )  # input of rank 1 means no batch dim
        elif op.name == "nn.functional.layer_norm":
            normalized_shape = input.args[0]
            is_supported_input = (
                input.input.shape != normalized_shape
            )  # would cause inter-batch operations
        elif op.name in convolutions:
            # currently can't deal with padding computation on Python level
            is_supported_input = input.input.dim() == batched_input_size[op.name]
        elif op.name == "nn.functional.embedding":
            idx = input.args[0]
            is_supported_input = len(idx.shape) > 1  # there's no batch size
        else:
            is_supported_input = True
        is_supported_input = (
            is_supported_input and input.input.shape[0] > 0
        )  # 0 is not a valid batch size
        return is_supported_input if supported_inputs else not is_supported_input

    return [input for input in sample_inputs if filter_fn(input)]


def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
    # get per sample grads by getting derivative for each input in a for loop
    per_sample_grad = []
    for i in range(batch_size):
        per_sample_input = input[i]
        result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs))
        diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
        diff_input_list = [
            i
            for i in diff_input_list
            if isinstance(i, torch.Tensor) and i.requires_grad
        ]
        per_sample_grad.append(
            torch.autograd.grad(
                result, diff_input_list, torch.ones_like(result), allow_unused=True
            )
        )
    if len(per_sample_grad) == batch_size:
        per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad))
    return per_sample_grad


def is_diff_tensor(t):
    return isinstance(t, ExpandedWeight) or (
        isinstance(t, torch.Tensor) and t.requires_grad
    )


def clone_if_tensor(t):
    if isinstance(t, torch.Tensor):
        res = torch.clone(t).detach()
        res.requires_grad_(t.requires_grad)
        return res
    else:
        return t


instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
instantiate_device_type_tests(TestExpandedWeightFunctional, globals())
instantiate_device_type_tests(TestExpandedWeightModule, globals())
if __name__ == "__main__":
    run_tests()