xref: /aosp_15_r20/external/pytorch/test/test_expanded_weights.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: nn"]
2import unittest
3from dataclasses import dataclass
4from functools import partial
5from itertools import chain, product
6
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10from torch.nn import CrossEntropyLoss
11from torch.nn.utils._expanded_weights import ExpandedWeight
12from torch.nn.utils._expanded_weights.expanded_weights_utils import (
13    forward_helper,
14    set_grad_sample_if_exists,
15    standard_kwargs,
16    sum_over_all_but_batch_and_last_n,
17    unpack_expanded_weight_or_tensor,
18)
19from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
20from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off
21from torch.testing._internal.common_device_type import (
22    instantiate_device_type_tests,
23    OpDTypes,
24    ops,
25)
26from torch.testing._internal.common_methods_invocations import op_db, SampleInput
27from torch.testing._internal.common_modules import module_db, modules
28from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
29from torch.testing._internal.common_utils import (
30    freeze_rng_state,
31    make_tensor,
32    parametrize,
33    run_tests,
34    skipIfTorchDynamo,
35    TestCase,
36)
37from torch.utils._pytree import tree_map_only
38
39
40class TestContext:
41    pass
42
43
44class TestExpandedWeightHelperFunction(TestCase):
45    def test_forward_helper(self, device):
46        input = torch.randn(3, 4, device=device)
47        weight = torch.randn(5, 4, device=device)
48        bias = torch.randn(5, device=device)
49        for weight_batched, bias_batched in product([True, False], [True, False]):
50            maybe_batched_weight = weight
51            maybe_batched_bias = bias
52            if weight_batched:
53                maybe_batched_weight = ExpandedWeight(
54                    weight.clone().requires_grad_(), 3, loss_reduction="sum"
55                )
56            if bias_batched:
57                maybe_batched_bias = ExpandedWeight(
58                    bias.clone().requires_grad_(), 3, loss_reduction="sum"
59                )
60            args = (input, maybe_batched_weight, maybe_batched_bias)
61            expanded_args, expanded_kwargs = standard_kwargs(("bias",), args)
62            res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
63            expected = nn.functional.linear(input, weight, bias)
64            self.assertEqual(res, expected)
65
66            self.assertEqual(len(expanded_args), 2)
67            assert expanded_args[0] is args[0]  # avoids property checks in assertEquals
68            assert expanded_args[1] is args[1]  # avoids property checks in assertEquals
69            self.assertEqual(len(expanded_kwargs), 1)
70            assert (
71                expanded_kwargs["bias"] is args[2]
72            )  # avoids property checks in assertEquals
73
74    def test_forward_helper_failure_args(self, device):
75        weight = torch.randn(5, 4, device=device)
76        bias = torch.randn(5, device=device)
77        with self.assertRaisesRegex(
78            RuntimeError, r"do not support inputs that are also ExpandedWeights."
79        ):
80            input = ExpandedWeight(
81                torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum"
82            )
83            expanded_args, expanded_kwargs = standard_kwargs(
84                ("bias",), (input, weight, bias)
85            )
86            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
87        with self.assertRaisesRegex(
88            RuntimeError, r"requires a Tensor as the first input"
89        ):
90            expanded_args, expanded_kwargs = standard_kwargs(
91                ("bias",), (3, weight, bias)
92            )
93            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
94        with self.assertRaisesRegex(
95            RuntimeError, r"requires a batch dimension but got an input of size 0"
96        ):
97            expanded_args, expanded_kwargs = standard_kwargs(
98                ("bias",), (torch.tensor(3), weight, bias)
99            )
100            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
101        with self.assertRaisesRegex(
102            RuntimeError, r"0 is not a valid batch size for Expanded Weights"
103        ):
104            expanded_args, expanded_kwargs = standard_kwargs(
105                ("bias",), (torch.randn(0, 1, 2), weight, bias)
106            )
107            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
108        input = torch.randn(3, 4)
109        for weight_batched, bias_batched in product([True, False], [True, False]):
110            if not weight_batched and not bias_batched:
111                continue
112            maybe_batched_weight = weight
113            maybe_batched_bias = bias
114            if weight_batched:
115                maybe_batched_weight = ExpandedWeight(
116                    weight.clone().requires_grad_(), 4, loss_reduction="sum"
117                )
118            if bias_batched:
119                maybe_batched_bias = ExpandedWeight(
120                    bias.clone().requires_grad_(), 4, loss_reduction="sum"
121                )
122            with self.assertRaisesRegex(
123                RuntimeError,
124                r"Expected ExpandedWeights to have batch size matching input",
125            ):
126                expanded_args, expanded_kwargs = standard_kwargs(
127                    ("bias",), (input, maybe_batched_weight, maybe_batched_bias)
128                )
129                forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
130
131    def test_set_grad_sample_if_exists(self, device):
132        def test_fn(a):
133            return grad_sample
134
135        orig_weight = torch.randn(4, device=device, requires_grad=True)
136        expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
137        grad_sample = torch.randn(3)
138        set_grad_sample_if_exists(expanded_weight, test_fn)
139        self.assertTrue(hasattr(orig_weight, "grad_sample"))
140        self.assertEqual(orig_weight.grad_sample, grad_sample)
141
142        basic_tensor = torch.randn(4, device=device)
143        set_grad_sample_if_exists(basic_tensor, test_fn)
144        self.assertFalse(hasattr(basic_tensor, "grad_sample"))
145
146        non_tensor = 3
147        set_grad_sample_if_exists(non_tensor, test_fn)
148        self.assertFalse(hasattr(non_tensor, "grad_sample"))
149
150    def test_set_grad_sample_if_exists_failure(self, device):
151        def test_fn(a):
152            return True
153
154        grad_tensor = torch.randn(4, requires_grad=True, device=device)
155        with self.assertRaisesRegex(
156            RuntimeError,
157            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
158        ):
159            set_grad_sample_if_exists(grad_tensor, test_fn)
160
161    def test_unpack_expanded_weight_or_tensor(self, device):
162        input = torch.randn(3, requires_grad=True, device=device)
163        self.assertEqual(
164            input,
165            unpack_expanded_weight_or_tensor(
166                ExpandedWeight(input, 3, loss_reduction="sum")
167            ),
168        )
169
170        input.requires_grad_(False)
171        self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
172        self.assertTrue(unpack_expanded_weight_or_tensor(4) is None)
173
174    def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
175        input = torch.randn(3, requires_grad=True, device=device)
176        self.assertTrue(
177            unpack_expanded_weight_or_tensor(
178                ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input
179            )
180        )
181
182        input.requires_grad_(False)
183        self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
184        self.assertTrue(
185            unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None
186        )
187
188    def test_unpack_expanded_weight_or_tensor_failure(self, device):
189        input = torch.randn(3, requires_grad=True, device=device)
190        with self.assertRaisesRegex(
191            RuntimeError,
192            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
193        ):
194            unpack_expanded_weight_or_tensor(input)
195
196        with self.assertRaisesRegex(
197            RuntimeError,
198            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
199        ):
200            unpack_expanded_weight_or_tensor(input, lambda x: x is input)
201
202    def test_sum_over_all_but_batch_and_last_n(self, device):
203        input = torch.randn(1, 2, 3, 4, 5, device=device)
204        res = sum_over_all_but_batch_and_last_n(input, 2)
205        expected = input.sum((1, 2))
206        self.assertEqual(res, expected)
207
208        res = sum_over_all_but_batch_and_last_n(input, 0)
209        expected = input.sum((1, 2, 3, 4))
210        self.assertEqual(res, expected)
211
212        res = sum_over_all_but_batch_and_last_n(input, 4)
213        self.assertEqual(res, input)
214
215
216class TestExpandedWeightFunctional(TestCase):
217    def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
218        input = sample_input.input
219        args = sample_input.args
220        kwargs = sample_input.kwargs
221        batch_size = input.shape[0] if len(input.shape) > 1 else 1
222
223        # get per sample grads with ExpandedWeights objects
224        loss_reduction = "sum" if reduction == torch.sum else "mean"
225        (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
226            sample_input, batch_size, loss_reduction
227        )
228        diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
229        diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
230        diff_input_list = [
231            i.orig_weight if isinstance(i, ExpandedWeight) else i
232            for i in diff_input_list
233        ]
234        if not diff_input_list:
235            return
236        result = run_op(op, ew_input, *ew_args, **ew_kwargs)
237        reduction(
238            result
239        ).backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__
240        expanded_weight_grad = tuple(
241            i.grad_sample if hasattr(i, "grad_sample") else i.grad
242            for i in diff_input_list
243        )
244
245        # get per sample grads with for loop
246        func = partial(run_op, op)
247
248        per_sample_grad = for_loop_per_sample_grad(
249            batch_size, reduction, input, func, *args, **kwargs
250        )
251
252        # check equality
253        self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
254        if loss_reduction == "mean":
255            # don't check equality of `input.grad`s since these vanilla tensors won't be scaled
256            expanded_weight_grad = expanded_weight_grad[1:]
257            per_sample_grad = per_sample_grad[1:]
258        for result_grad, expected_grad in zip(expanded_weight_grad, per_sample_grad):
259            self.assertEqual(result_grad, expected_grad)
260
261    @ops(
262        filter(lambda op: op.supports_expanded_weight, op_db),
263        dtypes=OpDTypes.supported,
264        allowed_dtypes=(torch.double,),
265    )
266    def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
267        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
268        for sample_input in supported_inputs(op, sample_inputs):
269            if (
270                op.name == "nn.functional.embedding"
271            ):  # embedding flips its argument order for autograd tests
272                sample_input = SampleInput(
273                    sample_input.args[0],
274                    args=(sample_input.input,),
275                    kwargs=sample_input.kwargs,
276                )
277
278            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)
279
280    @ops(
281        filter(lambda op: op.supports_expanded_weight, op_db),
282        dtypes=OpDTypes.supported,
283        allowed_dtypes=(torch.double,),
284    )
285    def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
286        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
287        for sample_input in supported_inputs(op, sample_inputs):
288            if (
289                op.name == "nn.functional.embedding"
290            ):  # embedding flips its argument order for autograd tests
291                sample_input = SampleInput(
292                    sample_input.args[0],
293                    args=(sample_input.input,),
294                    kwargs=sample_input.kwargs,
295                )
296
297            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)
298
299    @ops(
300        filter(lambda op: op.supports_expanded_weight, op_db),
301        dtypes=OpDTypes.supported,
302        allowed_dtypes=(torch.double,),
303    )
304    def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op):
305        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
306        for sample_input in supported_inputs(op, sample_inputs):
307            if (
308                op.name == "nn.functional.embedding"
309            ):  # embedding flips its argument order for autograd tests
310                sample_input = SampleInput(
311                    sample_input.args[0],
312                    args=(sample_input.input,),
313                    kwargs=sample_input.kwargs,
314                )
315            sample_input.input.requires_grad_(False)
316
317            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)
318
319    @skipIfTorchDynamo("Checking error message doesn't work with dynamo")
320    @ops(
321        filter(lambda op: op.supports_expanded_weight, op_db),
322        dtypes=OpDTypes.supported,
323        allowed_dtypes=(torch.double,),
324    )
325    def test_unsupported_expand_weights(self, device, dtype, op):
326        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
327        unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False)
328        for sample_input in unsupported_inputs:
329            with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"):
330                if (
331                    op.name == "nn.functional.embedding"
332                ):  # embedding flips its argument order for autograd tests
333                    sample_input = SampleInput(
334                        sample_input.args[0],
335                        args=(sample_input.input,),
336                        kwargs=sample_input.kwargs,
337                    )
338                input = sample_input.input
339
340                batch_size = input.shape[0] if len(input.shape) > 1 else 1
341
342                # get per sample grads with ExpandedWeights objects
343                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
344                    sample_input, batch_size
345                )
346                result = run_op(op, ew_input, *ew_args, **ew_kwargs)
347                diff_input_list = (
348                    (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
349                )
350                diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
351                diff_input_list = [
352                    i.orig_weight if isinstance(i, ExpandedWeight) else i
353                    for i in diff_input_list
354                ]
355                result.sum().backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__
356
357    @ops(
358        filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported
359    )
360    def test_expanded_weight_forward(self, device, dtype, op):
361        sample_inputs = op.sample_inputs(device, dtype)
362        for sample_input in supported_inputs(op, sample_inputs):
363            if (
364                op.name == "nn.functional.embedding"
365            ):  # embedding flips its argument order for autograd tests
366                sample_input = SampleInput(
367                    sample_input.args[0].clone(),
368                    args=(sample_input.input.clone(),),
369                    kwargs=sample_input.kwargs,
370                )
371                if (
372                    "cuda" in device
373                    and "max_norm" in sample_input.kwargs
374                    and "padding_idx" in sample_input.kwargs
375                ):
376                    self.skipTest(
377                        "embedding is non-determinstic in this case, see issue #74679"
378                    )
379            batch_size = (
380                sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
381            )
382            for loss_reduction in ["sum", "mean"]:
383                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
384                    sample_input, batch_size, loss_reduction
385                )
386                expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
387                normal_result = run_op(
388                    op, sample_input.input, *sample_input.args, **sample_input.kwargs
389                )
390                self.assertEqual(expanded_weight_result, normal_result)
391
392    def test_expanded_weight_error(self, device):
393        batch_size = 3
394        sample_input = make_tensor(
395            (batch_size, 4), dtype=torch.float32, device=device, requires_grad=True
396        )
397        sample_weight = make_tensor(
398            (4), dtype=torch.float32, device=device, requires_grad=True
399        )
400        with self.assertRaisesRegex(
401            RuntimeError, r"Expanded Weights encountered but cannot handle function"
402        ):
403            torch.add(
404                sample_input,
405                ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"),
406            )
407
408    def _test_embedding_model(self, model, num_embedding, device):
409        batch_size = 32
410        input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
411        return self._test_model(
412            partial(model, num_embedding=num_embedding), batch_size, input, device
413        )
414
415    def _test_conv_model(
416        self,
417        model,
418        input_size,
419        num_dim,
420        device,
421        loss_reduction="sum",
422        atol=1e-4,
423        rtol=5e-5,
424    ):
425        batch_size = 32
426        input_ending = [input_size] * num_dim
427        input = torch.randn([batch_size, 3] + input_ending, device=device)
428        return self._test_model(
429            partial(model, num_dim=num_dim),
430            batch_size,
431            input,
432            device,
433            loss_reduction,
434            atol,
435            rtol,
436        )
437
438    def _test_model(
439        self,
440        model,
441        batch_size,
442        input,
443        device,
444        loss_reduction="sum",
445        atol=1e-4,
446        rtol=5e-5,
447    ):
448        model = model(10).to(device)
449        targets = torch.randint(0, 10, (batch_size,), device=device)
450        criterion = CrossEntropyLoss(reduction=loss_reduction)
451        result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input)
452        loss = criterion(result, targets)
453        loss.backward()
454        result = []
455        for weight in model.parameters():
456            result.append(weight.grad_sample)
457            del weight.grad_sample
458
459        expected = []
460        for i in range(batch_size):
461            loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0))
462            expected.append(
463                torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss))
464            )
465
466        expected = [torch.stack(grad) for grad in zip(*expected)]
467        for res, exp in zip(result, expected):
468            self.assertEqual(res, exp, atol=atol, rtol=rtol)
469
470    def _compute_tolerances(self, device):
471        is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(
472            0
473        ) == (8, 6)
474        return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5)
475
476    @tf32_off()
477    def test_cnn_model_sum(self, device):
478        def convnet(num_classes, num_dim):
479            return nn.Sequential(
480                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
481                nn.ReLU(),
482                nn.AvgPool2d(kernel_size=2, stride=2),
483                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
484                nn.ReLU(),
485                nn.AvgPool2d(kernel_size=2, stride=2),
486                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
487                nn.ReLU(),
488                nn.AvgPool2d(kernel_size=2, stride=2),
489                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
490                nn.ReLU(),
491                nn.AdaptiveAvgPool2d((1, 1)),
492                nn.Flatten(start_dim=1, end_dim=-1),
493                nn.Linear(128, num_classes, bias=True),
494            )
495
496        atol, rtol = self._compute_tolerances(device)
497        return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol)
498
499    @tf32_off()
500    def test_cnn_model_mean(self, device):
501        def convnet(num_classes, num_dim):
502            return nn.Sequential(
503                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
504                nn.ReLU(),
505                nn.AvgPool2d(kernel_size=2, stride=2),
506                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
507                nn.ReLU(),
508                nn.AvgPool2d(kernel_size=2, stride=2),
509                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
510                nn.ReLU(),
511                nn.AvgPool2d(kernel_size=2, stride=2),
512                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
513                nn.ReLU(),
514                nn.AdaptiveAvgPool2d((1, 1)),
515                nn.Flatten(start_dim=1, end_dim=-1),
516                nn.Linear(128, num_classes, bias=True),
517            )
518
519        atol, rtol = self._compute_tolerances(device)
520        return self._test_conv_model(
521            convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol
522        )
523
524    @parametrize("num_dim", [1, 2, 3])
525    @tf32_off()
526    def test_instance_norm_model(self, num_dim, device):
527        def instance_norm_model(num_classes, num_dim):
528            conv_layer = (
529                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
530            )
531            norm_layer = (
532                nn.InstanceNorm1d
533                if num_dim == 1
534                else nn.InstanceNorm2d
535                if num_dim == 2
536                else nn.InstanceNorm3d
537            )
538            return nn.Sequential(
539                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
540                norm_layer(32, affine=True),
541                nn.Flatten(start_dim=1, end_dim=-1),
542                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
543            )
544
545        atol, rtol = self._compute_tolerances(device)
546        return self._test_conv_model(
547            instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
548        )
549
550    @parametrize("num_dim", [1, 2, 3])
551    @tf32_off()
552    def test_group_norm_model(self, num_dim, device):
553        def group_norm_model(num_classes, num_dim):
554            conv_layer = (
555                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
556            )
557            return nn.Sequential(
558                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
559                nn.GroupNorm(8, 32, affine=True),
560                nn.Flatten(start_dim=1, end_dim=-1),
561                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
562            )
563
564        atol, rtol = self._compute_tolerances(device)
565        return self._test_conv_model(
566            group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
567        )
568
569    @parametrize("num_dim", [1, 2, 3])
570    @tf32_off()
571    def test_layer_norm_model(self, num_dim, device):
572        def layer_norm_model(num_classes, num_dim):
573            conv_layer = (
574                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
575            )
576            normalized_shape = [7] * num_dim
577            return nn.Sequential(
578                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
579                nn.LayerNorm(normalized_shape, elementwise_affine=True),
580                nn.Flatten(start_dim=1, end_dim=-1),
581                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
582            )
583
584        atol, rtol = self._compute_tolerances(device)
585        return self._test_conv_model(
586            layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
587        )
588
589    def test_embedding_model(self, device):
590        def embedding_model(num_classes, num_embedding):
591            return nn.Sequential(
592                nn.Embedding(num_embedding, 15),
593                nn.Flatten(start_dim=1, end_dim=-1),
594                nn.Linear(375, num_classes, bias=True),
595            )
596
597        return self._test_embedding_model(embedding_model, 16, device)
598
599    def test_group_norm_error(self, device):
600        # group norm has to call native_group_norm. This checks that it hits the same errors
601        # that normal group norm would
602
603        N = 3
604        C = 5
605        inp = torch.randn(N, C)
606        with self.assertRaisesRegex(
607            RuntimeError, r"Expected number of channels in input to be divisible"
608        ):
609            F.group_norm(inp, 2)  # 5 is not divisible by 2
610
611
612class TestExpandedWeightModule(TestCase):
613    def _do_test(
614        self,
615        module,
616        input,
617        args=None,
618        kwargs=None,
619        batch_first=True,
620        atol=None,
621        rtol=None,
622    ):
623        args = args or ()
624        kwargs = kwargs or {}
625
626        batch_dim = 0 if batch_first else 1
627        batch_size = input.shape[batch_dim]
628        diff_input = input.dtype == torch.float or input.dtype == torch.double
629        if diff_input:
630            input.requires_grad_()
631
632        with freeze_rng_state():
633            # get per sample grads with ExpandedWeights context manager
634            actual_res = call_for_per_sample_grads(
635                module,
636                batch_size=batch_size,
637                loss_reduction="sum",
638                batch_first=batch_first,
639            )(input, *args, **kwargs).sum()
640            actual_res.backward()
641            actual_grads = []
642            for param in module.parameters():
643                actual_grads.append(param.grad_sample)
644                del param.grad_sample
645            if diff_input:
646                actual_grads.append(input.grad.clone())
647                input.grad = torch.zeros_like(input.grad)
648
649            # get per sample grads with a for loop
650            expected_res = torch.tensor(
651                0.0, device=input.device, dtype=actual_res.dtype
652            )
653            expected_grads = []
654            for i in range(batch_size):
655                input_slice = input.narrow(batch_dim, i, 1)
656                input_slice = input_slice.squeeze(batch_dim)
657
658                # h's batch dim is always the first dim. Must be contiguous for CUDA
659                sliced_args = tree_map_only(
660                    torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args
661                )
662                diff_params = module.parameters()
663                if diff_input:
664                    diff_params = chain(diff_params, (input_slice,))
665                res = module(
666                    input_slice.unsqueeze(batch_dim).contiguous(),
667                    *sliced_args,
668                    **kwargs,
669                ).sum()
670                out_grads = torch.autograd.grad(
671                    res, diff_params, torch.ones_like(res), allow_unused=True
672                )
673                expected_grads.append(out_grads)
674                expected_res += res
675            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
676            if not batch_first:
677                expected_grads[-1] = expected_grads[-1].transpose(0, 1)
678        self.assertEqual(actual_res, expected_res)
679        [
680            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
681            for (actual, expected) in zip(actual_grads, expected_grads)
682        ]
683
684    def _do_test_multi_input(self, module, input):
685        class TestModule(nn.Module):
686            def __init__(self, module):
687                super().__init__()
688                self.module = module
689
690            def forward(self, input):
691                return self.module(input) + self.module(input)
692
693        batch_size = input.shape[0]
694        diff_input = input.dtype == torch.float or input.dtype == torch.double
695        if diff_input:
696            input.requires_grad_()
697        with freeze_rng_state():
698            # get per sample grads with ExpandedWeights context manager, calling .backward() twice
699            test_module = TestModule(module)
700            actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")(
701                input
702            ).sum()
703            actual_res.backward()
704            actual_grads = []
705            for param in module.parameters():
706                actual_grads.append(param.grad_sample)
707                del param.grad_sample
708            if diff_input:
709                actual_grads.append(input.grad.clone())
710                input.grad = torch.zeros_like(input.grad)
711
712            # get per sample grads with a for loop, running over the input twice
713            expected_grads = []
714            for i in range(batch_size):
715                input_slice = input[i]
716                diff_params = module.parameters()
717                if diff_input:
718                    diff_params = chain(diff_params, (input_slice,))
719                res = module(input_slice.unsqueeze(0)).sum()
720                out_grads = torch.autograd.grad(
721                    res, diff_params, torch.ones_like(res), allow_unused=True
722                )
723                expected_grads.append(out_grads)
724        expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
725        expected_grads = tuple(
726            expected_grad
727            for expected_grad in expected_grads
728            if expected_grad is not None
729        )
730        assert [
731            self.assertEqual(actual, 2 * expected)
732            for (actual, expected) in zip(actual_grads, expected_grads)
733        ]
734
735    def _do_test_rnn_packed_sequence(
736        self, module, input, args=None, kwargs=None, atol=None, rtol=None
737    ):
738        args = args if args is not None else ()
739        kwargs = kwargs if kwargs is not None else {}
740
741        batch_size = max(tuple(input.batch_sizes)).item()
742
743        with freeze_rng_state():
744            # get per sample grads with ExpandedWeights context manager
745            actual_res = call_for_per_sample_grads(
746                module, batch_size=batch_size, loss_reduction="sum"
747            )(input, *args, **kwargs).data.sum()
748            actual_res.backward()
749            actual_grads = []
750            for param in module.parameters():
751                self.assertEqual(param.grad_sample.shape[0], batch_size)
752                actual_grads.append(param.grad_sample)
753                del param.grad_sample
754
755            input.data.grad = torch.zeros_like(input.data)
756
757            # compute the per sample grads with a for loop
758            expected_res = torch.zeros_like(actual_res)
759            expected_grads = []
760            padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence(
761                input, batch_first=True
762            )
763            for i in range(len(seq_sizes)):
764                input_slice = padded_input[i].narrow(0, 0, seq_sizes[i])
765                diff_params = module.parameters()
766                batch_dim = 0 if module.m.batch_first else 1
767                res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum()
768                expected_res += res
769                out_grads = torch.autograd.grad(
770                    res, diff_params, torch.ones_like(res), allow_unused=True
771                )
772                expected_grads.append(out_grads)
773
774            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
775            self.assertEqual(actual_res, expected_res)
776            [
777                self.assertEqual(actual, expected, atol=atol, rtol=rtol)
778                for (actual, expected) in zip(actual_grads, expected_grads)
779            ]
780
781    @modules(
782        filter(
783            lambda m_info: m_info.module_cls
784            in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
785            module_db,
786        )
787    )
788    @tf32_off()
789    def test_module(self, device, dtype, module_info, training):
790        class RNNWrapper(torch.nn.Module):
791            def __init__(self, m_cons, args, kwargs):
792                super().__init__()
793                self.m = m_cons(*args, **kwargs)
794
795            def forward(self, *inps):
796                ret = self.m(*inps)
797                assert isinstance(ret, tuple)
798                return ret[0]
799
800        def batch_hidden(h):
801            new_h_shape = [1] * (len(h.shape) + 1)
802            new_h_shape[1] = 2
803            return h.unsqueeze(1).repeat(new_h_shape)
804
805        module_cls = module_info.module_cls
806        atol, rtol = (
807            (1e-4, 1e-5)
808            if module_cls == torch.nn.GRU and dtype == torch.float32
809            else (None, None)
810        )
811        module_inputs = module_info.module_inputs_func(
812            module_info,
813            device=device,
814            dtype=dtype,
815            requires_grad=True,
816            training=training,
817            with_packed_sequence=True,
818        )
819        for module_input in module_inputs:
820            if module_input.forward_input is None:
821                continue
822            args, kwargs = (
823                module_input.constructor_input.args,
824                module_input.constructor_input.kwargs,
825            )
826            m = RNNWrapper(module_cls, args, kwargs)
827            batch_first = m.m.batch_first
828            m.to(device).to(dtype)
829
830            args, kwargs = (
831                module_input.forward_input.args,
832                module_input.forward_input.kwargs,
833            )
834
835            # if the RNN tests use unbatched inputs--batch the inputs
836            input = args[0]
837            if isinstance(input, torch.Tensor) and input.dim() == 2:
838                input = input.detach()
839                new_input_shape = [1] * (len(input.shape) + 1)
840                if batch_first:
841                    new_input_shape[0] = 2
842                    input = input.repeat(new_input_shape)
843                else:
844                    new_input_shape[1] = 2
845                    input = input.unsqueeze(1).repeat(new_input_shape)
846
847                h = args[1] if len(args) > 1 else None
848                if h is not None:
849                    h = (
850                        batch_hidden(h)
851                        if isinstance(h, torch.Tensor)
852                        else tuple(batch_hidden(hx) for hx in h)
853                    )
854                    args = list(args)
855                    args[1] = h
856
857            if isinstance(input, torch.nn.utils.rnn.PackedSequence):
858                self._do_test_rnn_packed_sequence(
859                    m, input, args[1:], kwargs, atol=atol, rtol=rtol
860                )
861            else:
862                self._do_test(
863                    m,
864                    input,
865                    args[1:],
866                    kwargs,
867                    batch_first=batch_first,
868                    atol=atol,
869                    rtol=rtol,
870                )
871
872    def test_per_sample_api_failing(self):
873        module = nn.Linear(10, 10)
874        input = torch.randn(64, 10)
875        with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
876            call_for_per_sample_grads("fail")(input)
877        with self.assertRaisesRegex(
878            RuntimeError, r"Batch size passed must be None or an integer"
879        ):
880            call_for_per_sample_grads(module, batch_size=6.4)(input)
881        with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
882            call_for_per_sample_grads(module, batch_size=-64)(input)
883        with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
884            loss = call_for_per_sample_grads(module)(input).sum()
885            loss.backward()  # populate grad_sample fields
886            call_for_per_sample_grads(module)(input)
887
888        module = nn.Linear(10, 10)  # reset to not have grad_sample fields
889        with self.assertRaisesRegex(
890            RuntimeError, r"Expected loss_reduction argument to be sum or mean"
891        ):
892            call_for_per_sample_grads(module, loss_reduction="")(input)
893
894    def test_per_sample_api_compute_batch_size(self):
895        class CustomModule(nn.Module):
896            def __init__(self) -> None:
897                super().__init__()
898                self.linear = nn.Linear(5, 5)
899
900            def forward(self, input1, input2):
901                return self.linear(input1) + self.linear(input2)
902
903        module = CustomModule()
904        input1 = torch.randn(4, 5)
905        input2 = torch.randn(5, 5)
906
907        with self.assertRaisesRegex(
908            RuntimeError,
909            "found at least one input with batch size 4 and one with batch size 5",
910        ):
911            call_for_per_sample_grads(module)(input1, input2)
912
913        input2 = torch.randn(4, 5)
914        call_for_per_sample_grads(module)(input1, input2)
915
916        module = CustomModule()
917        call_for_per_sample_grads(module)(input1, input2=input2)
918
919        module = CustomModule()
920        call_for_per_sample_grads(module)(input1=input1, input2=input2)
921
922    def test_per_sample_api_compute_batch_size_not_pytreeable(self):
923        @dataclass
924        class NonPytreeableTuple:
925            elem1: torch.Tensor
926            elem2: torch.Tensor
927
928        class CustomModule(nn.Module):
929            def __init__(self) -> None:
930                super().__init__()
931                self.linear = nn.Linear(5, 5)
932
933            def forward(self, input1, input2):
934                return self.linear(input1.elem1) + self.linear(input1.elem2)
935
936        input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5))
937        model = CustomModule()
938        with self.assertRaisesRegex(
939            RuntimeError,
940            "ExpandedWeights cannot compute the batch size from the inputs",
941        ):
942            call_for_per_sample_grads(model)(input, "")
943
944        # would prefer for it to error because input is not pytree-able but that's hard to detect
945        with self.assertRaisesRegex(
946            RuntimeError, "Expected ExpandedWeights to have batch size matching input"
947        ):
948            call_for_per_sample_grads(model)(input, torch.randn(5))
949
950        model = CustomModule()  # TODO: functional call bug, sam will fix
951        call_for_per_sample_grads(model)(input, torch.randn(4, 5))
952        model = CustomModule()
953        call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5))
954
955
956class ContextManagerTests(TestBase):
957    def __init__(self, *args, **kwargs):
958        self.test_cpu = kwargs.get("test_cpu", True)
959        self.test_cuda = kwargs.get("test_cuda", True)
960        super().__init__(*args, **kwargs)
961
962    @property
963    def constructor_args(self):
964        return self._get_arg("constructor_args", False)
965
966    def test_context_manager(self, test_case, device):
967        kwargs = {"device": device, "dtype": torch.double}
968        module = self.constructor(*self.constructor_args).to(**kwargs)
969        if "Embedding" in self.get_name():
970            kwargs["dtype"] = torch.long
971        input = self._get_input().to(**kwargs)
972        if len(input.shape) == 0 or input.shape[0] == 0:
973            raise unittest.SkipTest(
974                "Can't get per sample gradients when no batch dim or batch dim is 0"
975            )
976        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
977            raise unittest.SkipTest(
978                "Can't get per sample gradients for input of rank 1"
979            )
980        test_case._do_test(module, input)
981
982    def test_context_manager_multiple_inputs(self, test_case, device):
983        module = self.constructor(*self.constructor_args).to(device)
984        input = self._get_input()
985        if len(input.shape) == 0 or input.shape[0] == 0:
986            raise unittest.SkipTest(
987                "Can't get per sample gradients when no batch dim or batch dim is 0"
988            )
989        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
990            raise unittest.SkipTest(
991                "Can't get per sample gradients for input of rank 1"
992            )
993        test_case._do_test_multi_input(module, input)
994
995
996def filter_supported_tests(t):
997    supported_modules = [
998        "Linear",
999        "Conv1d",
1000        "Conv2d",
1001        "Conv3d",
1002        "Embedding",
1003        "LayerNorm",
1004        "GroupNorm",
1005        "InstanceNorm",
1006    ]
1007    if "module_name" in t and t["module_name"] in supported_modules:
1008        return True
1009
1010
1011# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
1012# These currently use the legacy nn tests
1013supported_tests = [
1014    t for t in module_tests + new_module_tests if filter_supported_tests(t)
1015]
1016for test_param in supported_tests:
1017    if "constructor" not in test_param:
1018        name = test_param.pop("module_name")
1019        test_param["constructor"] = getattr(nn, name)
1020    decorator = test_param.pop("decorator", lambda test: test)
1021    test = ContextManagerTests(**test_param)
1022    test_name = test.get_name()
1023    if hasattr(TestExpandedWeightModule, test_name):
1024        raise RuntimeError("Found two tests with the same name: " + test_name)
1025    test_name_multi_input = test.get_name() + "_multiple_inputs"
1026    if hasattr(TestExpandedWeightModule, test_name_multi_input):
1027        raise RuntimeError("Found two tests with the same name: " + test_name)
1028    if test.test_cpu:
1029        setattr(
1030            TestExpandedWeightModule,
1031            test_name,
1032            decorator(lambda self, test=test: test.test_context_manager(self, "cpu")),
1033        )
1034        setattr(
1035            TestExpandedWeightModule,
1036            test_name_multi_input,
1037            decorator(
1038                lambda self, test=test: test.test_context_manager_multiple_inputs(
1039                    self, "cpu"
1040                )
1041            ),
1042        )
1043    if TEST_CUDA and test.test_cuda:
1044        # since this checks derivatives, only use double for precision
1045        setattr(
1046            TestExpandedWeightModule,
1047            test_name + "_cuda_double",
1048            decorator(lambda self, test=test: test.test_context_manager(self, "cuda")),
1049        )
1050
1051# ------------- HELPER FUNCTIONS -----------------
1052
1053
1054def run_op(op, input, *args, **kwargs):
1055    r"""
1056    OpInfo for Embedding switches the input and weight so autograd tests will only check the derivative
1057    of the weight, not the input, which can't be differentiable since its dtype is int. Calls op,
1058    using the special ordering that Embedding's OpInfo expects for that case.
1059    """
1060    if op.name == "nn.functional.embedding":
1061        return op(args[0], input, **kwargs)
1062    else:
1063        return op(input, *args, **kwargs)
1064
1065
1066def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
1067    def expanded_weight_or_clone(arg):
1068        if is_diff_tensor(arg):
1069            return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
1070        return clone_if_tensor(arg)
1071
1072    ew_input = clone_if_tensor(sample_input.input)
1073    ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
1074    ew_kwargs = {
1075        name: expanded_weight_or_clone(arg)
1076        for (name, arg) in sample_input.kwargs.items()
1077    }
1078    return ew_input, ew_args, ew_kwargs
1079
1080
1081def supported_inputs(op, sample_inputs, supported_inputs=True):
1082    r"""
1083    ExpandedWeights currently does not support some use cases when there's no batch dimension or
1084    operations that would cause inter-batch operations. Removes all of the cases it cannot deal with
1085    """
1086
1087    def filter_fn(input):
1088        convolutions = [
1089            "nn.functional.conv1d",
1090            "nn.functional.conv2d",
1091            "nn.functional.conv3d",
1092        ]
1093        batched_input_size = dict(zip(convolutions, [3, 4, 5]))
1094        if op.name == "nn.functional.linear":
1095            is_supported_input = (
1096                input.input.dim() > 1
1097            )  # input of rank 1 means no batch dim
1098        elif op.name == "nn.functional.layer_norm":
1099            normalized_shape = input.args[0]
1100            is_supported_input = (
1101                input.input.shape != normalized_shape
1102            )  # would cause inter-batch operations
1103        elif op.name in convolutions:
1104            # currently can't deal with padding computation on Python level
1105            is_supported_input = input.input.dim() == batched_input_size[op.name]
1106        elif op.name == "nn.functional.embedding":
1107            idx = input.args[0]
1108            is_supported_input = len(idx.shape) > 1  # there's no batch size
1109        else:
1110            is_supported_input = True
1111        is_supported_input = (
1112            is_supported_input and input.input.shape[0] > 0
1113        )  # 0 is not a valid batch size
1114        return is_supported_input if supported_inputs else not is_supported_input
1115
1116    return [input for input in sample_inputs if filter_fn(input)]
1117
1118
1119def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
1120    # get per sample grads by getting derivative for each input in a for loop
1121    per_sample_grad = []
1122    for i in range(batch_size):
1123        per_sample_input = input[i]
1124        result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs))
1125        diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
1126        diff_input_list = [
1127            i
1128            for i in diff_input_list
1129            if isinstance(i, torch.Tensor) and i.requires_grad
1130        ]
1131        per_sample_grad.append(
1132            torch.autograd.grad(
1133                result, diff_input_list, torch.ones_like(result), allow_unused=True
1134            )
1135        )
1136    if len(per_sample_grad) == batch_size:
1137        per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad))
1138    return per_sample_grad
1139
1140
1141def is_diff_tensor(t):
1142    return isinstance(t, ExpandedWeight) or (
1143        isinstance(t, torch.Tensor) and t.requires_grad
1144    )
1145
1146
1147def clone_if_tensor(t):
1148    if isinstance(t, torch.Tensor):
1149        res = torch.clone(t).detach()
1150        res.requires_grad_(t.requires_grad)
1151        return res
1152    else:
1153        return t
1154
1155
1156instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
1157instantiate_device_type_tests(TestExpandedWeightFunctional, globals())
1158instantiate_device_type_tests(TestExpandedWeightModule, globals())
1159if __name__ == "__main__":
1160    run_tests()
1161