# Owner(s): ["module: nn"]
import unittest
from dataclasses import dataclass
from functools import partial
from itertools import chain, product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.utils._expanded_weights import ExpandedWeight
from torch.nn.utils._expanded_weights.expanded_weights_utils import (
    forward_helper,
    set_grad_sample_if_exists,
    standard_kwargs,
    sum_over_all_but_batch_and_last_n,
    unpack_expanded_weight_or_tensor,
)
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
from torch.testing._internal.common_utils import (
    freeze_rng_state,
    make_tensor,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._pytree import tree_map_only

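# Tests for the (private) ExpandedWeights machinery behind
# torch.nn.utils._per_sample_grad.call_for_per_sample_grads: wrapping a
# parameter in an ExpandedWeight makes a single forward/backward populate
# `param.grad_sample` with per-sample gradients.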

class TestContext:
    pass


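# Unit tests for the helper functions in
# torch.nn.utils._expanded_weights.expanded_weights_utils, which the per-op
# implementations use to unwrap ExpandedWeight arguments and to stash
# per-sample gradients.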
class TestExpandedWeightHelperFunction(TestCase):
    def test_forward_helper(self, device):
        input = torch.randn(3, 4, device=device)
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        for weight_batched, bias_batched in product([True, False], [True, False]):
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 3, loss_reduction="sum"
                )
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 3, loss_reduction="sum"
                )
            args = (input, maybe_batched_weight, maybe_batched_bias)
            expanded_args, expanded_kwargs = standard_kwargs(("bias",), args)
            res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
            expected = nn.functional.linear(input, weight, bias)
            self.assertEqual(res, expected)

            self.assertEqual(len(expanded_args), 2)
            assert expanded_args[0] is args[0]  # avoids property checks in assertEquals
            assert expanded_args[1] is args[1]  # avoids property checks in assertEquals
            self.assertEqual(len(expanded_kwargs), 1)
            assert (
                expanded_kwargs["bias"] is args[2]
            )  # avoids property checks in assertEquals

    def test_forward_helper_failure_args(self, device):
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        with self.assertRaisesRegex(
            RuntimeError, r"do not support inputs that are also ExpandedWeights."
        ):
            input = ExpandedWeight(
                torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum"
            )
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (input, weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"requires a Tensor as the first input"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (3, weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"requires a batch dimension but got an input of size 0"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (torch.tensor(3), weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"0 is not a valid batch size for Expanded Weights"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (torch.randn(0, 1, 2), weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        input = torch.randn(3, 4)
        for weight_batched, bias_batched in product([True, False], [True, False]):
            if not weight_batched and not bias_batched:
                continue
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 4, loss_reduction="sum"
                )
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 4, loss_reduction="sum"
                )
            with self.assertRaisesRegex(
                RuntimeError,
                r"Expected ExpandedWeights to have batch size matching input",
            ):
                expanded_args, expanded_kwargs = standard_kwargs(
                    ("bias",), (input, maybe_batched_weight, maybe_batched_bias)
                )
                forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

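    # set_grad_sample_if_exists stores the result of the given function on the
    # wrapped tensor's `.grad_sample` when passed an ExpandedWeight and is a
    # no-op for plain non-differentiable tensors and non-tensors.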
    def test_set_grad_sample_if_exists(self, device):
        def test_fn(a):
            return grad_sample

        orig_weight = torch.randn(4, device=device, requires_grad=True)
        expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
        grad_sample = torch.randn(3)
        set_grad_sample_if_exists(expanded_weight, test_fn)
        self.assertTrue(hasattr(orig_weight, "grad_sample"))
        self.assertEqual(orig_weight.grad_sample, grad_sample)

        basic_tensor = torch.randn(4, device=device)
        set_grad_sample_if_exists(basic_tensor, test_fn)
        self.assertFalse(hasattr(basic_tensor, "grad_sample"))

        non_tensor = 3
        set_grad_sample_if_exists(non_tensor, test_fn)
        self.assertFalse(hasattr(non_tensor, "grad_sample"))

    def test_set_grad_sample_if_exists_failure(self, device):
        def test_fn(a):
            return True

        grad_tensor = torch.randn(4, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            set_grad_sample_if_exists(grad_tensor, test_fn)

    def test_unpack_expanded_weight_or_tensor(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertEqual(
            input,
            unpack_expanded_weight_or_tensor(
                ExpandedWeight(input, 3, loss_reduction="sum")
            ),
        )

        input.requires_grad_(False)
        self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
        self.assertTrue(unpack_expanded_weight_or_tensor(4) is None)

    def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertTrue(
            unpack_expanded_weight_or_tensor(
                ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input
            )
        )

        input.requires_grad_(False)
        self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
        self.assertTrue(
            unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None
        )

    def test_unpack_expanded_weight_or_tensor_failure(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            unpack_expanded_weight_or_tensor(input)

        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            unpack_expanded_weight_or_tensor(input, lambda x: x is input)

    def test_sum_over_all_but_batch_and_last_n(self, device):
        input = torch.randn(1, 2, 3, 4, 5, device=device)
        res = sum_over_all_but_batch_and_last_n(input, 2)
        expected = input.sum((1, 2))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 0)
        expected = input.sum((1, 2, 3, 4))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 4)
        self.assertEqual(res, input)


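# Checks per-sample gradients produced through ExpandedWeights against a
# reference computed with an explicit for-loop over the batch, both for
# individual ops from the OpInfo database and for small end-to-end models.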
class TestExpandedWeightFunctional(TestCase):
    def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
        input = sample_input.input
        args = sample_input.args
        kwargs = sample_input.kwargs
        batch_size = input.shape[0] if len(input.shape) > 1 else 1

        # get per sample grads with ExpandedWeights objects
        loss_reduction = "sum" if reduction == torch.sum else "mean"
        (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
            sample_input, batch_size, loss_reduction
        )
        diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
        diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
        diff_input_list = [
            i.orig_weight if isinstance(i, ExpandedWeight) else i
            for i in diff_input_list
        ]
        if not diff_input_list:
            return
        result = run_op(op, ew_input, *ew_args, **ew_kwargs)
        reduction(
            result
        ).backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__
        expanded_weight_grad = tuple(
            i.grad_sample if hasattr(i, "grad_sample") else i.grad
            for i in diff_input_list
        )

        # get per sample grads with for loop
        func = partial(run_op, op)

        per_sample_grad = for_loop_per_sample_grad(
            batch_size, reduction, input, func, *args, **kwargs
        )

        # check equality
        self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
        if loss_reduction == "mean":
            # don't check equality of `input.grad`s since these vanilla tensors won't be scaled
            expanded_weight_grad = expanded_weight_grad[1:]
            per_sample_grad = per_sample_grad[1:]
        for result_grad, expected_grad in zip(expanded_weight_grad, per_sample_grad):
            self.assertEqual(result_grad, expected_grad)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )
            sample_input.input.requires_grad_(False)

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @skipIfTorchDynamo("Checking error message doesn't work with dynamo")
    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_unsupported_expand_weights(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False)
        for sample_input in unsupported_inputs:
            with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"):
                if (
                    op.name == "nn.functional.embedding"
                ):  # embedding flips its argument order for autograd tests
                    sample_input = SampleInput(
                        sample_input.args[0],
                        args=(sample_input.input,),
                        kwargs=sample_input.kwargs,
                    )
                input = sample_input.input

                batch_size = input.shape[0] if len(input.shape) > 1 else 1

                # get per sample grads with ExpandedWeights objects
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size
                )
                result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                diff_input_list = (
                    (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
                )
                diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
                diff_input_list = [
                    i.orig_weight if isinstance(i, ExpandedWeight) else i
                    for i in diff_input_list
                ]
                result.sum().backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported
    )
    def test_expanded_weight_forward(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0].clone(),
                    args=(sample_input.input.clone(),),
                    kwargs=sample_input.kwargs,
                )
                if (
                    "cuda" in device
                    and "max_norm" in sample_input.kwargs
                    and "padding_idx" in sample_input.kwargs
                ):
                    self.skipTest(
                        "embedding is non-deterministic in this case, see issue #74679"
                    )
            batch_size = (
                sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
            )
            for loss_reduction in ["sum", "mean"]:
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size, loss_reduction
                )
                expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                normal_result = run_op(
                    op, sample_input.input, *sample_input.args, **sample_input.kwargs
                )
                self.assertEqual(expanded_weight_result, normal_result)

    def test_expanded_weight_error(self, device):
        batch_size = 3
        sample_input = make_tensor(
            (batch_size, 4), dtype=torch.float32, device=device, requires_grad=True
        )
        sample_weight = make_tensor(
            (4), dtype=torch.float32, device=device, requires_grad=True
        )
        with self.assertRaisesRegex(
            RuntimeError, r"Expanded Weights encountered but cannot handle function"
        ):
            torch.add(
                sample_input,
                ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"),
            )

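    # The _test_*_model helpers compare per-sample gradients from
    # call_for_per_sample_grads (stored on each parameter's `.grad_sample`)
    # against per-example gradients computed one sample at a time with
    # torch.autograd.grad.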
    def _test_embedding_model(self, model, num_embedding, device):
        batch_size = 32
        input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
        return self._test_model(
            partial(model, num_embedding=num_embedding), batch_size, input, device
        )

    def _test_conv_model(
        self,
        model,
        input_size,
        num_dim,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        batch_size = 32
        input_ending = [input_size] * num_dim
        input = torch.randn([batch_size, 3] + input_ending, device=device)
        return self._test_model(
            partial(model, num_dim=num_dim),
            batch_size,
            input,
            device,
            loss_reduction,
            atol,
            rtol,
        )

    def _test_model(
        self,
        model,
        batch_size,
        input,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        model = model(10).to(device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        criterion = CrossEntropyLoss(reduction=loss_reduction)
        result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input)
        loss = criterion(result, targets)
        loss.backward()
        result = []
        for weight in model.parameters():
            result.append(weight.grad_sample)
            del weight.grad_sample

        expected = []
        for i in range(batch_size):
            loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0))
            expected.append(
                torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss))
            )

        expected = [torch.stack(grad) for grad in zip(*expected)]
        for res, exp in zip(result, expected):
            self.assertEqual(res, exp, atol=atol, rtol=rtol)

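    # Conv-model comparisons use a looser absolute tolerance on sm_86 CUDA
    # devices.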
    def _compute_tolerances(self, device):
        is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(
            0
        ) == (8, 6)
        return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5)

    @tf32_off()
    def test_cnn_model_sum(self, device):
        def convnet(num_classes, num_dim):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol)

    @tf32_off()
    def test_cnn_model_mean(self, device):
        def convnet(num_classes, num_dim):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_instance_norm_model(self, num_dim, device):
        def instance_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            norm_layer = (
                nn.InstanceNorm1d
                if num_dim == 1
                else nn.InstanceNorm2d
                if num_dim == 2
                else nn.InstanceNorm3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                norm_layer(32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_group_norm_model(self, num_dim, device):
        def group_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(8, 32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_layer_norm_model(self, num_dim, device):
        def layer_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            normalized_shape = [7] * num_dim
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.LayerNorm(normalized_shape, elementwise_affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    def test_embedding_model(self, device):
        def embedding_model(num_classes, num_embedding):
            return nn.Sequential(
                nn.Embedding(num_embedding, 15),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(375, num_classes, bias=True),
            )

        return self._test_embedding_model(embedding_model, 16, device)

    def test_group_norm_error(self, device):
        # group norm has to call native_group_norm. This checks that it hits the same errors
        # that normal group norm would

        N = 3
        C = 5
        inp = torch.randn(N, C)
        with self.assertRaisesRegex(
            RuntimeError, r"Expected number of channels in input to be divisible"
        ):
            F.group_norm(inp, 2)  # 5 is not divisible by 2


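# Checks module-level per-sample gradients: call_for_per_sample_grads is run on
# whole nn.Modules and the resulting `.grad_sample` values on each parameter
# are compared against gradients computed one batch element at a time.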
612*da0073e9SAndroid Build Coastguard Workerclass TestExpandedWeightModule(TestCase):
613*da0073e9SAndroid Build Coastguard Worker    def _do_test(
614*da0073e9SAndroid Build Coastguard Worker        self,
615*da0073e9SAndroid Build Coastguard Worker        module,
616*da0073e9SAndroid Build Coastguard Worker        input,
617*da0073e9SAndroid Build Coastguard Worker        args=None,
618*da0073e9SAndroid Build Coastguard Worker        kwargs=None,
619*da0073e9SAndroid Build Coastguard Worker        batch_first=True,
620*da0073e9SAndroid Build Coastguard Worker        atol=None,
621*da0073e9SAndroid Build Coastguard Worker        rtol=None,
622*da0073e9SAndroid Build Coastguard Worker    ):
623*da0073e9SAndroid Build Coastguard Worker        args = args or ()
624*da0073e9SAndroid Build Coastguard Worker        kwargs = kwargs or {}
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker        batch_dim = 0 if batch_first else 1
627*da0073e9SAndroid Build Coastguard Worker        batch_size = input.shape[batch_dim]
628*da0073e9SAndroid Build Coastguard Worker        diff_input = input.dtype == torch.float or input.dtype == torch.double
629*da0073e9SAndroid Build Coastguard Worker        if diff_input:
630*da0073e9SAndroid Build Coastguard Worker            input.requires_grad_()
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
633*da0073e9SAndroid Build Coastguard Worker            # get per sample grads with ExpandedWeights context manager
634*da0073e9SAndroid Build Coastguard Worker            actual_res = call_for_per_sample_grads(
635*da0073e9SAndroid Build Coastguard Worker                module,
636*da0073e9SAndroid Build Coastguard Worker                batch_size=batch_size,
637*da0073e9SAndroid Build Coastguard Worker                loss_reduction="sum",
638*da0073e9SAndroid Build Coastguard Worker                batch_first=batch_first,
639*da0073e9SAndroid Build Coastguard Worker            )(input, *args, **kwargs).sum()
640*da0073e9SAndroid Build Coastguard Worker            actual_res.backward()
641*da0073e9SAndroid Build Coastguard Worker            actual_grads = []
642*da0073e9SAndroid Build Coastguard Worker            for param in module.parameters():
643*da0073e9SAndroid Build Coastguard Worker                actual_grads.append(param.grad_sample)
644*da0073e9SAndroid Build Coastguard Worker                del param.grad_sample
645*da0073e9SAndroid Build Coastguard Worker            if diff_input:
646*da0073e9SAndroid Build Coastguard Worker                actual_grads.append(input.grad.clone())
647*da0073e9SAndroid Build Coastguard Worker                input.grad = torch.zeros_like(input.grad)
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Worker            # get per sample grads with a for loop
650*da0073e9SAndroid Build Coastguard Worker            expected_res = torch.tensor(
651*da0073e9SAndroid Build Coastguard Worker                0.0, device=input.device, dtype=actual_res.dtype
652*da0073e9SAndroid Build Coastguard Worker            )
653*da0073e9SAndroid Build Coastguard Worker            expected_grads = []
654*da0073e9SAndroid Build Coastguard Worker            for i in range(batch_size):
655*da0073e9SAndroid Build Coastguard Worker                input_slice = input.narrow(batch_dim, i, 1)
656*da0073e9SAndroid Build Coastguard Worker                input_slice = input_slice.squeeze(batch_dim)
657*da0073e9SAndroid Build Coastguard Worker
658*da0073e9SAndroid Build Coastguard Worker                # h's batch dim is always the first dim. Must be contiguous for CUDA
659*da0073e9SAndroid Build Coastguard Worker                sliced_args = tree_map_only(
660*da0073e9SAndroid Build Coastguard Worker                    torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args
661*da0073e9SAndroid Build Coastguard Worker                )
662*da0073e9SAndroid Build Coastguard Worker                diff_params = module.parameters()
663*da0073e9SAndroid Build Coastguard Worker                if diff_input:
664*da0073e9SAndroid Build Coastguard Worker                    diff_params = chain(diff_params, (input_slice,))
665*da0073e9SAndroid Build Coastguard Worker                res = module(
666*da0073e9SAndroid Build Coastguard Worker                    input_slice.unsqueeze(batch_dim).contiguous(),
667*da0073e9SAndroid Build Coastguard Worker                    *sliced_args,
668*da0073e9SAndroid Build Coastguard Worker                    **kwargs,
669*da0073e9SAndroid Build Coastguard Worker                ).sum()
670*da0073e9SAndroid Build Coastguard Worker                out_grads = torch.autograd.grad(
671*da0073e9SAndroid Build Coastguard Worker                    res, diff_params, torch.ones_like(res), allow_unused=True
672*da0073e9SAndroid Build Coastguard Worker                )
673*da0073e9SAndroid Build Coastguard Worker                expected_grads.append(out_grads)
674*da0073e9SAndroid Build Coastguard Worker                expected_res += res
675*da0073e9SAndroid Build Coastguard Worker            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
676*da0073e9SAndroid Build Coastguard Worker            if not batch_first:
677*da0073e9SAndroid Build Coastguard Worker                expected_grads[-1] = expected_grads[-1].transpose(0, 1)
678*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual_res, expected_res)
679*da0073e9SAndroid Build Coastguard Worker        for actual, expected in zip(actual_grads, expected_grads):
680*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
683*da0073e9SAndroid Build Coastguard Worker
684*da0073e9SAndroid Build Coastguard Worker    def _do_test_multi_input(self, module, input):
685*da0073e9SAndroid Build Coastguard Worker        class TestModule(nn.Module):
686*da0073e9SAndroid Build Coastguard Worker            def __init__(self, module):
687*da0073e9SAndroid Build Coastguard Worker                super().__init__()
688*da0073e9SAndroid Build Coastguard Worker                self.module = module
689*da0073e9SAndroid Build Coastguard Worker
690*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
691*da0073e9SAndroid Build Coastguard Worker                return self.module(input) + self.module(input)
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker        batch_size = input.shape[0]
694*da0073e9SAndroid Build Coastguard Worker        diff_input = input.dtype == torch.float or input.dtype == torch.double
695*da0073e9SAndroid Build Coastguard Worker        if diff_input:
696*da0073e9SAndroid Build Coastguard Worker            input.requires_grad_()
697*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
698*da0073e9SAndroid Build Coastguard Worker            # get per sample grads via call_for_per_sample_grads; the wrapped module runs twice in one forward, so its backward fires twice
699*da0073e9SAndroid Build Coastguard Worker            test_module = TestModule(module)
700*da0073e9SAndroid Build Coastguard Worker            actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")(
701*da0073e9SAndroid Build Coastguard Worker                input
702*da0073e9SAndroid Build Coastguard Worker            ).sum()
703*da0073e9SAndroid Build Coastguard Worker            actual_res.backward()
704*da0073e9SAndroid Build Coastguard Worker            actual_grads = []
705*da0073e9SAndroid Build Coastguard Worker            for param in module.parameters():
706*da0073e9SAndroid Build Coastguard Worker                actual_grads.append(param.grad_sample)
707*da0073e9SAndroid Build Coastguard Worker                del param.grad_sample
708*da0073e9SAndroid Build Coastguard Worker            if diff_input:
709*da0073e9SAndroid Build Coastguard Worker                actual_grads.append(input.grad.clone())
710*da0073e9SAndroid Build Coastguard Worker                input.grad = torch.zeros_like(input.grad)
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker            # get per sample grads with a for loop, applying the module once per sample; the ExpandedWeights grads should be twice these
713*da0073e9SAndroid Build Coastguard Worker            expected_grads = []
714*da0073e9SAndroid Build Coastguard Worker            for i in range(batch_size):
715*da0073e9SAndroid Build Coastguard Worker                input_slice = input[i]
716*da0073e9SAndroid Build Coastguard Worker                diff_params = module.parameters()
717*da0073e9SAndroid Build Coastguard Worker                if diff_input:
718*da0073e9SAndroid Build Coastguard Worker                    diff_params = chain(diff_params, (input_slice,))
719*da0073e9SAndroid Build Coastguard Worker                res = module(input_slice.unsqueeze(0)).sum()
720*da0073e9SAndroid Build Coastguard Worker                out_grads = torch.autograd.grad(
721*da0073e9SAndroid Build Coastguard Worker                    res, diff_params, torch.ones_like(res), allow_unused=True
722*da0073e9SAndroid Build Coastguard Worker                )
723*da0073e9SAndroid Build Coastguard Worker                expected_grads.append(out_grads)
724*da0073e9SAndroid Build Coastguard Worker        expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
725*da0073e9SAndroid Build Coastguard Worker        expected_grads = tuple(
726*da0073e9SAndroid Build Coastguard Worker            expected_grad
727*da0073e9SAndroid Build Coastguard Worker            for expected_grad in expected_grads
728*da0073e9SAndroid Build Coastguard Worker            if expected_grad is not None
729*da0073e9SAndroid Build Coastguard Worker        )
730*da0073e9SAndroid Build Coastguard Worker        for actual, expected in zip(actual_grads, expected_grads):
731*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, 2 * expected)
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker    def _do_test_rnn_packed_sequence(
736*da0073e9SAndroid Build Coastguard Worker        self, module, input, args=None, kwargs=None, atol=None, rtol=None
737*da0073e9SAndroid Build Coastguard Worker    ):
738*da0073e9SAndroid Build Coastguard Worker        args = args if args is not None else ()
739*da0073e9SAndroid Build Coastguard Worker        kwargs = kwargs if kwargs is not None else {}
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker        batch_size = max(tuple(input.batch_sizes)).item()
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
744*da0073e9SAndroid Build Coastguard Worker            # get per sample grads via call_for_per_sample_grads (ExpandedWeights)
745*da0073e9SAndroid Build Coastguard Worker            actual_res = call_for_per_sample_grads(
746*da0073e9SAndroid Build Coastguard Worker                module, batch_size=batch_size, loss_reduction="sum"
747*da0073e9SAndroid Build Coastguard Worker            )(input, *args, **kwargs).data.sum()
748*da0073e9SAndroid Build Coastguard Worker            actual_res.backward()
749*da0073e9SAndroid Build Coastguard Worker            actual_grads = []
750*da0073e9SAndroid Build Coastguard Worker            for param in module.parameters():
751*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(param.grad_sample.shape[0], batch_size)
752*da0073e9SAndroid Build Coastguard Worker                actual_grads.append(param.grad_sample)
753*da0073e9SAndroid Build Coastguard Worker                del param.grad_sample
754*da0073e9SAndroid Build Coastguard Worker
755*da0073e9SAndroid Build Coastguard Worker            input.data.grad = torch.zeros_like(input.data)
756*da0073e9SAndroid Build Coastguard Worker
757*da0073e9SAndroid Build Coastguard Worker            # compute the per sample grads with a for loop
758*da0073e9SAndroid Build Coastguard Worker            expected_res = torch.zeros_like(actual_res)
759*da0073e9SAndroid Build Coastguard Worker            expected_grads = []
760*da0073e9SAndroid Build Coastguard Worker            padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence(
761*da0073e9SAndroid Build Coastguard Worker                input, batch_first=True
762*da0073e9SAndroid Build Coastguard Worker            )
763*da0073e9SAndroid Build Coastguard Worker            for i in range(len(seq_sizes)):
764*da0073e9SAndroid Build Coastguard Worker                input_slice = padded_input[i].narrow(0, 0, seq_sizes[i])
765*da0073e9SAndroid Build Coastguard Worker                diff_params = module.parameters()
766*da0073e9SAndroid Build Coastguard Worker                batch_dim = 0 if module.m.batch_first else 1
767*da0073e9SAndroid Build Coastguard Worker                res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum()
768*da0073e9SAndroid Build Coastguard Worker                expected_res += res
769*da0073e9SAndroid Build Coastguard Worker                out_grads = torch.autograd.grad(
770*da0073e9SAndroid Build Coastguard Worker                    res, diff_params, torch.ones_like(res), allow_unused=True
771*da0073e9SAndroid Build Coastguard Worker                )
772*da0073e9SAndroid Build Coastguard Worker                expected_grads.append(out_grads)
773*da0073e9SAndroid Build Coastguard Worker
774*da0073e9SAndroid Build Coastguard Worker            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
775*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual_res, expected_res)
776*da0073e9SAndroid Build Coastguard Worker            for actual, expected in zip(actual_grads, expected_grads):
777*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected, atol=atol, rtol=rtol)
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker    @modules(
782*da0073e9SAndroid Build Coastguard Worker        filter(
783*da0073e9SAndroid Build Coastguard Worker            lambda m_info: m_info.module_cls
784*da0073e9SAndroid Build Coastguard Worker            in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
785*da0073e9SAndroid Build Coastguard Worker            module_db,
786*da0073e9SAndroid Build Coastguard Worker        )
787*da0073e9SAndroid Build Coastguard Worker    )
788*da0073e9SAndroid Build Coastguard Worker    @tf32_off()
789*da0073e9SAndroid Build Coastguard Worker    def test_module(self, device, dtype, module_info, training):
790*da0073e9SAndroid Build Coastguard Worker        class RNNWrapper(torch.nn.Module):
791*da0073e9SAndroid Build Coastguard Worker            def __init__(self, m_cons, args, kwargs):
792*da0073e9SAndroid Build Coastguard Worker                super().__init__()
793*da0073e9SAndroid Build Coastguard Worker                self.m = m_cons(*args, **kwargs)
794*da0073e9SAndroid Build Coastguard Worker
795*da0073e9SAndroid Build Coastguard Worker            def forward(self, *inps):
796*da0073e9SAndroid Build Coastguard Worker                ret = self.m(*inps)
797*da0073e9SAndroid Build Coastguard Worker                assert isinstance(ret, tuple)
798*da0073e9SAndroid Build Coastguard Worker                return ret[0]
799*da0073e9SAndroid Build Coastguard Worker
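        # batch_hidden adds a batch dimension of size 2 to an unbatched hidden state,
        # e.g. a hypothetical h of shape (num_layers, hidden_size) becomes
        # (num_layers, 2, hidden_size); for LSTM it is applied to both h and c.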
800*da0073e9SAndroid Build Coastguard Worker        def batch_hidden(h):
801*da0073e9SAndroid Build Coastguard Worker            new_h_shape = [1] * (len(h.shape) + 1)
802*da0073e9SAndroid Build Coastguard Worker            new_h_shape[1] = 2
803*da0073e9SAndroid Build Coastguard Worker            return h.unsqueeze(1).repeat(new_h_shape)
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker        module_cls = module_info.module_cls
806*da0073e9SAndroid Build Coastguard Worker        atol, rtol = (
807*da0073e9SAndroid Build Coastguard Worker            (1e-4, 1e-5)
808*da0073e9SAndroid Build Coastguard Worker            if module_cls == torch.nn.GRU and dtype == torch.float32
809*da0073e9SAndroid Build Coastguard Worker            else (None, None)
810*da0073e9SAndroid Build Coastguard Worker        )
811*da0073e9SAndroid Build Coastguard Worker        module_inputs = module_info.module_inputs_func(
812*da0073e9SAndroid Build Coastguard Worker            module_info,
813*da0073e9SAndroid Build Coastguard Worker            device=device,
814*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
815*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
816*da0073e9SAndroid Build Coastguard Worker            training=training,
817*da0073e9SAndroid Build Coastguard Worker            with_packed_sequence=True,
818*da0073e9SAndroid Build Coastguard Worker        )
819*da0073e9SAndroid Build Coastguard Worker        for module_input in module_inputs:
820*da0073e9SAndroid Build Coastguard Worker            if module_input.forward_input is None:
821*da0073e9SAndroid Build Coastguard Worker                continue
822*da0073e9SAndroid Build Coastguard Worker            args, kwargs = (
823*da0073e9SAndroid Build Coastguard Worker                module_input.constructor_input.args,
824*da0073e9SAndroid Build Coastguard Worker                module_input.constructor_input.kwargs,
825*da0073e9SAndroid Build Coastguard Worker            )
826*da0073e9SAndroid Build Coastguard Worker            m = RNNWrapper(module_cls, args, kwargs)
827*da0073e9SAndroid Build Coastguard Worker            batch_first = m.m.batch_first
828*da0073e9SAndroid Build Coastguard Worker            m.to(device).to(dtype)
829*da0073e9SAndroid Build Coastguard Worker
830*da0073e9SAndroid Build Coastguard Worker            args, kwargs = (
831*da0073e9SAndroid Build Coastguard Worker                module_input.forward_input.args,
832*da0073e9SAndroid Build Coastguard Worker                module_input.forward_input.kwargs,
833*da0073e9SAndroid Build Coastguard Worker            )
834*da0073e9SAndroid Build Coastguard Worker
835*da0073e9SAndroid Build Coastguard Worker            # if the RNN tests use unbatched inputs, batch them
836*da0073e9SAndroid Build Coastguard Worker            input = args[0]
837*da0073e9SAndroid Build Coastguard Worker            if isinstance(input, torch.Tensor) and input.dim() == 2:
838*da0073e9SAndroid Build Coastguard Worker                input = input.detach()
839*da0073e9SAndroid Build Coastguard Worker                new_input_shape = [1] * (len(input.shape) + 1)
840*da0073e9SAndroid Build Coastguard Worker                if batch_first:
841*da0073e9SAndroid Build Coastguard Worker                    new_input_shape[0] = 2
842*da0073e9SAndroid Build Coastguard Worker                    input = input.repeat(new_input_shape)
843*da0073e9SAndroid Build Coastguard Worker                else:
844*da0073e9SAndroid Build Coastguard Worker                    new_input_shape[1] = 2
845*da0073e9SAndroid Build Coastguard Worker                    input = input.unsqueeze(1).repeat(new_input_shape)
846*da0073e9SAndroid Build Coastguard Worker
847*da0073e9SAndroid Build Coastguard Worker                h = args[1] if len(args) > 1 else None
848*da0073e9SAndroid Build Coastguard Worker                if h is not None:
849*da0073e9SAndroid Build Coastguard Worker                    h = (
850*da0073e9SAndroid Build Coastguard Worker                        batch_hidden(h)
851*da0073e9SAndroid Build Coastguard Worker                        if isinstance(h, torch.Tensor)
852*da0073e9SAndroid Build Coastguard Worker                        else tuple(batch_hidden(hx) for hx in h)
853*da0073e9SAndroid Build Coastguard Worker                    )
854*da0073e9SAndroid Build Coastguard Worker                    args = list(args)
855*da0073e9SAndroid Build Coastguard Worker                    args[1] = h
856*da0073e9SAndroid Build Coastguard Worker
857*da0073e9SAndroid Build Coastguard Worker            if isinstance(input, torch.nn.utils.rnn.PackedSequence):
858*da0073e9SAndroid Build Coastguard Worker                self._do_test_rnn_packed_sequence(
859*da0073e9SAndroid Build Coastguard Worker                    m, input, args[1:], kwargs, atol=atol, rtol=rtol
860*da0073e9SAndroid Build Coastguard Worker                )
861*da0073e9SAndroid Build Coastguard Worker            else:
862*da0073e9SAndroid Build Coastguard Worker                self._do_test(
863*da0073e9SAndroid Build Coastguard Worker                    m,
864*da0073e9SAndroid Build Coastguard Worker                    input,
865*da0073e9SAndroid Build Coastguard Worker                    args[1:],
866*da0073e9SAndroid Build Coastguard Worker                    kwargs,
867*da0073e9SAndroid Build Coastguard Worker                    batch_first=batch_first,
868*da0073e9SAndroid Build Coastguard Worker                    atol=atol,
869*da0073e9SAndroid Build Coastguard Worker                    rtol=rtol,
870*da0073e9SAndroid Build Coastguard Worker                )
871*da0073e9SAndroid Build Coastguard Worker
872*da0073e9SAndroid Build Coastguard Worker    def test_per_sample_api_failing(self):
873*da0073e9SAndroid Build Coastguard Worker        module = nn.Linear(10, 10)
874*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(64, 10)
875*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
876*da0073e9SAndroid Build Coastguard Worker            call_for_per_sample_grads("fail")(input)
877*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
878*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Batch size passed must be None or an integer"
879*da0073e9SAndroid Build Coastguard Worker        ):
880*da0073e9SAndroid Build Coastguard Worker            call_for_per_sample_grads(module, batch_size=6.4)(input)
881*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
882*da0073e9SAndroid Build Coastguard Worker            call_for_per_sample_grads(module, batch_size=-64)(input)
883*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
884*da0073e9SAndroid Build Coastguard Worker            loss = call_for_per_sample_grads(module)(input).sum()
885*da0073e9SAndroid Build Coastguard Worker            loss.backward()  # populate grad_sample fields
886*da0073e9SAndroid Build Coastguard Worker            call_for_per_sample_grads(module)(input)
887*da0073e9SAndroid Build Coastguard Worker
888*da0073e9SAndroid Build Coastguard Worker        module = nn.Linear(10, 10)  # reset to not have grad_sample fields
889*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
890*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Expected loss_reduction argument to be sum or mean"
891*da0073e9SAndroid Build Coastguard Worker        ):
892*da0073e9SAndroid Build Coastguard Worker            call_for_per_sample_grads(module, loss_reduction="")(input)
893*da0073e9SAndroid Build Coastguard Worker
894*da0073e9SAndroid Build Coastguard Worker    def test_per_sample_api_compute_batch_size(self):
895*da0073e9SAndroid Build Coastguard Worker        class CustomModule(nn.Module):
896*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
897*da0073e9SAndroid Build Coastguard Worker                super().__init__()
898*da0073e9SAndroid Build Coastguard Worker                self.linear = nn.Linear(5, 5)
899*da0073e9SAndroid Build Coastguard Worker
900*da0073e9SAndroid Build Coastguard Worker            def forward(self, input1, input2):
901*da0073e9SAndroid Build Coastguard Worker                return self.linear(input1) + self.linear(input2)
902*da0073e9SAndroid Build Coastguard Worker
903*da0073e9SAndroid Build Coastguard Worker        module = CustomModule()
904*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(4, 5)
905*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(5, 5)
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
908*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
909*da0073e9SAndroid Build Coastguard Worker            "found at least one input with batch size 4 and one with batch size 5",
910*da0073e9SAndroid Build Coastguard Worker        ):
911*da0073e9SAndroid Build Coastguard Worker            call_for_per_sample_grads(module)(input1, input2)
912*da0073e9SAndroid Build Coastguard Worker
913*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(4, 5)
914*da0073e9SAndroid Build Coastguard Worker        call_for_per_sample_grads(module)(input1, input2)
915*da0073e9SAndroid Build Coastguard Worker
916*da0073e9SAndroid Build Coastguard Worker        module = CustomModule()
917*da0073e9SAndroid Build Coastguard Worker        call_for_per_sample_grads(module)(input1, input2=input2)
918*da0073e9SAndroid Build Coastguard Worker
919*da0073e9SAndroid Build Coastguard Worker        module = CustomModule()
920*da0073e9SAndroid Build Coastguard Worker        call_for_per_sample_grads(module)(input1=input1, input2=input2)
921*da0073e9SAndroid Build Coastguard Worker
922*da0073e9SAndroid Build Coastguard Worker    def test_per_sample_api_compute_batch_size_not_pytreeable(self):
923*da0073e9SAndroid Build Coastguard Worker        @dataclass
924*da0073e9SAndroid Build Coastguard Worker        class NonPytreeableTuple:
925*da0073e9SAndroid Build Coastguard Worker            elem1: torch.Tensor
926*da0073e9SAndroid Build Coastguard Worker            elem2: torch.Tensor
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker        class CustomModule(nn.Module):
929*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
930*da0073e9SAndroid Build Coastguard Worker                super().__init__()
931*da0073e9SAndroid Build Coastguard Worker                self.linear = nn.Linear(5, 5)
932*da0073e9SAndroid Build Coastguard Worker
933*da0073e9SAndroid Build Coastguard Worker            def forward(self, input1, input2):
934*da0073e9SAndroid Build Coastguard Worker                return self.linear(input1.elem1) + self.linear(input1.elem2)
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker        input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5))
937*da0073e9SAndroid Build Coastguard Worker        model = CustomModule()
938*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
939*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
940*da0073e9SAndroid Build Coastguard Worker            "ExpandedWeights cannot compute the batch size from the inputs",
941*da0073e9SAndroid Build Coastguard Worker        ):
942*da0073e9SAndroid Build Coastguard Worker            call_for_per_sample_grads(model)(input, "")
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker        # ideally this would error because the input is not pytree-able, but that's hard to detect
945*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
946*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Expected ExpandedWeights to have batch size matching input"
947*da0073e9SAndroid Build Coastguard Worker        ):
948*da0073e9SAndroid Build Coastguard Worker            call_for_per_sample_grads(model)(input, torch.randn(5))
949*da0073e9SAndroid Build Coastguard Worker
950*da0073e9SAndroid Build Coastguard Worker        model = CustomModule()  # TODO: functional call bug, sam will fix
951*da0073e9SAndroid Build Coastguard Worker        call_for_per_sample_grads(model)(input, torch.randn(4, 5))
952*da0073e9SAndroid Build Coastguard Worker        model = CustomModule()
953*da0073e9SAndroid Build Coastguard Worker        call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5))
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker
956*da0073e9SAndroid Build Coastguard Workerclass ContextManagerTests(TestBase):
957*da0073e9SAndroid Build Coastguard Worker    def __init__(self, *args, **kwargs):
958*da0073e9SAndroid Build Coastguard Worker        self.test_cpu = kwargs.get("test_cpu", True)
959*da0073e9SAndroid Build Coastguard Worker        self.test_cuda = kwargs.get("test_cuda", True)
960*da0073e9SAndroid Build Coastguard Worker        super().__init__(*args, **kwargs)
961*da0073e9SAndroid Build Coastguard Worker
962*da0073e9SAndroid Build Coastguard Worker    @property
963*da0073e9SAndroid Build Coastguard Worker    def constructor_args(self):
964*da0073e9SAndroid Build Coastguard Worker        return self._get_arg("constructor_args", False)
965*da0073e9SAndroid Build Coastguard Worker
966*da0073e9SAndroid Build Coastguard Worker    def test_context_manager(self, test_case, device):
967*da0073e9SAndroid Build Coastguard Worker        kwargs = {"device": device, "dtype": torch.double}
968*da0073e9SAndroid Build Coastguard Worker        module = self.constructor(*self.constructor_args).to(**kwargs)
969*da0073e9SAndroid Build Coastguard Worker        if "Embedding" in self.get_name():
970*da0073e9SAndroid Build Coastguard Worker            kwargs["dtype"] = torch.long
971*da0073e9SAndroid Build Coastguard Worker        input = self._get_input().to(**kwargs)
972*da0073e9SAndroid Build Coastguard Worker        if len(input.shape) == 0 or input.shape[0] == 0:
973*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest(
974*da0073e9SAndroid Build Coastguard Worker                "Can't get per sample gradients when no batch dim or batch dim is 0"
975*da0073e9SAndroid Build Coastguard Worker            )
976*da0073e9SAndroid Build Coastguard Worker        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
977*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest(
978*da0073e9SAndroid Build Coastguard Worker                "Can't get per sample gradients for input of rank 1"
979*da0073e9SAndroid Build Coastguard Worker            )
980*da0073e9SAndroid Build Coastguard Worker        test_case._do_test(module, input)
981*da0073e9SAndroid Build Coastguard Worker
982*da0073e9SAndroid Build Coastguard Worker    def test_context_manager_multiple_inputs(self, test_case, device):
983*da0073e9SAndroid Build Coastguard Worker        module = self.constructor(*self.constructor_args).to(device)
984*da0073e9SAndroid Build Coastguard Worker        input = self._get_input()
985*da0073e9SAndroid Build Coastguard Worker        if len(input.shape) == 0 or input.shape[0] == 0:
986*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest(
987*da0073e9SAndroid Build Coastguard Worker                "Can't get per sample gradients when no batch dim or batch dim is 0"
988*da0073e9SAndroid Build Coastguard Worker            )
989*da0073e9SAndroid Build Coastguard Worker        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
990*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest(
991*da0073e9SAndroid Build Coastguard Worker                "Can't get per sample gradients for input of rank 1"
992*da0073e9SAndroid Build Coastguard Worker            )
993*da0073e9SAndroid Build Coastguard Worker        test_case._do_test_multi_input(module, input)
994*da0073e9SAndroid Build Coastguard Worker
995*da0073e9SAndroid Build Coastguard Worker
996*da0073e9SAndroid Build Coastguard Workerdef filter_supported_tests(t):
997*da0073e9SAndroid Build Coastguard Worker    supported_modules = [
998*da0073e9SAndroid Build Coastguard Worker        "Linear",
999*da0073e9SAndroid Build Coastguard Worker        "Conv1d",
1000*da0073e9SAndroid Build Coastguard Worker        "Conv2d",
1001*da0073e9SAndroid Build Coastguard Worker        "Conv3d",
1002*da0073e9SAndroid Build Coastguard Worker        "Embedding",
1003*da0073e9SAndroid Build Coastguard Worker        "LayerNorm",
1004*da0073e9SAndroid Build Coastguard Worker        "GroupNorm",
1005*da0073e9SAndroid Build Coastguard Worker        "InstanceNorm",
1006*da0073e9SAndroid Build Coastguard Worker    ]
1007*da0073e9SAndroid Build Coastguard Worker    return "module_name" in t and t["module_name"] in supported_modules
1009*da0073e9SAndroid Build Coastguard Worker
1010*da0073e9SAndroid Build Coastguard Worker
1011*da0073e9SAndroid Build Coastguard Worker# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
1012*da0073e9SAndroid Build Coastguard Worker# These currently use the legacy nn tests
1013*da0073e9SAndroid Build Coastguard Workersupported_tests = [
1014*da0073e9SAndroid Build Coastguard Worker    t for t in module_tests + new_module_tests if filter_supported_tests(t)
1015*da0073e9SAndroid Build Coastguard Worker]
1016*da0073e9SAndroid Build Coastguard Workerfor test_param in supported_tests:
1017*da0073e9SAndroid Build Coastguard Worker    if "constructor" not in test_param:
1018*da0073e9SAndroid Build Coastguard Worker        name = test_param.pop("module_name")
1019*da0073e9SAndroid Build Coastguard Worker        test_param["constructor"] = getattr(nn, name)
1020*da0073e9SAndroid Build Coastguard Worker    decorator = test_param.pop("decorator", lambda test: test)
1021*da0073e9SAndroid Build Coastguard Worker    test = ContextManagerTests(**test_param)
1022*da0073e9SAndroid Build Coastguard Worker    test_name = test.get_name()
1023*da0073e9SAndroid Build Coastguard Worker    if hasattr(TestExpandedWeightModule, test_name):
1024*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError("Found two tests with the same name: " + test_name)
1025*da0073e9SAndroid Build Coastguard Worker    test_name_multi_input = test.get_name() + "_multiple_inputs"
1026*da0073e9SAndroid Build Coastguard Worker    if hasattr(TestExpandedWeightModule, test_name_multi_input):
1027*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError("Found two tests with the same name: " + test_name_multi_input)
1028*da0073e9SAndroid Build Coastguard Worker    if test.test_cpu:
1029*da0073e9SAndroid Build Coastguard Worker        setattr(
1030*da0073e9SAndroid Build Coastguard Worker            TestExpandedWeightModule,
1031*da0073e9SAndroid Build Coastguard Worker            test_name,
1032*da0073e9SAndroid Build Coastguard Worker            decorator(lambda self, test=test: test.test_context_manager(self, "cpu")),
1033*da0073e9SAndroid Build Coastguard Worker        )
1034*da0073e9SAndroid Build Coastguard Worker        setattr(
1035*da0073e9SAndroid Build Coastguard Worker            TestExpandedWeightModule,
1036*da0073e9SAndroid Build Coastguard Worker            test_name_multi_input,
1037*da0073e9SAndroid Build Coastguard Worker            decorator(
1038*da0073e9SAndroid Build Coastguard Worker                lambda self, test=test: test.test_context_manager_multiple_inputs(
1039*da0073e9SAndroid Build Coastguard Worker                    self, "cpu"
1040*da0073e9SAndroid Build Coastguard Worker                )
1041*da0073e9SAndroid Build Coastguard Worker            ),
1042*da0073e9SAndroid Build Coastguard Worker        )
1043*da0073e9SAndroid Build Coastguard Worker    if TEST_CUDA and test.test_cuda:
1044*da0073e9SAndroid Build Coastguard Worker        # since this checks derivatives, only test with double precision
1045*da0073e9SAndroid Build Coastguard Worker        setattr(
1046*da0073e9SAndroid Build Coastguard Worker            TestExpandedWeightModule,
1047*da0073e9SAndroid Build Coastguard Worker            test_name + "_cuda_double",
1048*da0073e9SAndroid Build Coastguard Worker            decorator(lambda self, test=test: test.test_context_manager(self, "cuda")),
1049*da0073e9SAndroid Build Coastguard Worker        )
1050*da0073e9SAndroid Build Coastguard Worker
1051*da0073e9SAndroid Build Coastguard Worker# ------------- HELPER FUNCTIONS -----------------
1052*da0073e9SAndroid Build Coastguard Worker
1053*da0073e9SAndroid Build Coastguard Worker
1054*da0073e9SAndroid Build Coastguard Workerdef run_op(op, input, *args, **kwargs):
1055*da0073e9SAndroid Build Coastguard Worker    r"""
1056*da0073e9SAndroid Build Coastguard Worker    OpInfo for Embedding switches the input and weight so autograd tests only check the derivative of
1057*da0073e9SAndroid Build Coastguard Worker    the weight, not the input, which cannot be differentiable since its dtype is int. This helper calls
1058*da0073e9SAndroid Build Coastguard Worker    op using the argument ordering that Embedding's OpInfo expects for that case.
1059*da0073e9SAndroid Build Coastguard Worker    """
1060*da0073e9SAndroid Build Coastguard Worker    if op.name == "nn.functional.embedding":
1061*da0073e9SAndroid Build Coastguard Worker        return op(args[0], input, **kwargs)
1062*da0073e9SAndroid Build Coastguard Worker    else:
1063*da0073e9SAndroid Build Coastguard Worker        return op(input, *args, **kwargs)
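
# Illustrative sketch (hypothetical names and shapes, not executed anywhere): for
# nn.functional.embedding the SampleInput stores the weight as `input` and the
# integer indices as `args[0]`, so run_op swaps them back before calling, e.g.
#   weight = torch.randn(10, 3, requires_grad=True)  # differentiable "input"
#   idx = torch.tensor([[0, 2, 4], [1, 3, 5]])       # args[0], integer dtype
#   out = run_op(embedding_op_info, weight, idx)     # dispatches op(idx, weight)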
1064*da0073e9SAndroid Build Coastguard Worker
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Workerdef make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
1067*da0073e9SAndroid Build Coastguard Worker    def expanded_weight_or_clone(arg):
1068*da0073e9SAndroid Build Coastguard Worker        if is_diff_tensor(arg):
1069*da0073e9SAndroid Build Coastguard Worker            return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
1070*da0073e9SAndroid Build Coastguard Worker        return clone_if_tensor(arg)
1071*da0073e9SAndroid Build Coastguard Worker
1072*da0073e9SAndroid Build Coastguard Worker    ew_input = clone_if_tensor(sample_input.input)
1073*da0073e9SAndroid Build Coastguard Worker    ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
1074*da0073e9SAndroid Build Coastguard Worker    ew_kwargs = {
1075*da0073e9SAndroid Build Coastguard Worker        name: expanded_weight_or_clone(arg)
1076*da0073e9SAndroid Build Coastguard Worker        for (name, arg) in sample_input.kwargs.items()
1077*da0073e9SAndroid Build Coastguard Worker    }
1078*da0073e9SAndroid Build Coastguard Worker    return ew_input, ew_args, ew_kwargs
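
# Usage sketch (illustrative, not executed anywhere): wrap the differentiable
# tensors of a SampleInput in ExpandedWeight, run the op on them, and call
# backward on a summed result; the functional tests then compare the resulting
# per sample gradients against a for-loop reference, roughly:
#   ew_input, ew_args, ew_kwargs = make_expanded_weight(sample_input, batch_size)
#   run_op(op, ew_input, *ew_args, **ew_kwargs).sum().backward()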
1079*da0073e9SAndroid Build Coastguard Worker
1080*da0073e9SAndroid Build Coastguard Worker
1081*da0073e9SAndroid Build Coastguard Workerdef supported_inputs(op, sample_inputs, supported_inputs=True):
1082*da0073e9SAndroid Build Coastguard Worker    r"""
1083*da0073e9SAndroid Build Coastguard Worker    ExpandedWeights currently does not support some use cases where there is no batch dimension or
1084*da0073e9SAndroid Build Coastguard Worker    where the op would perform inter-batch operations. Removes all of the cases it cannot handle.
1085*da0073e9SAndroid Build Coastguard Worker    """
1086*da0073e9SAndroid Build Coastguard Worker
1087*da0073e9SAndroid Build Coastguard Worker    def filter_fn(input):
1088*da0073e9SAndroid Build Coastguard Worker        convolutions = [
1089*da0073e9SAndroid Build Coastguard Worker            "nn.functional.conv1d",
1090*da0073e9SAndroid Build Coastguard Worker            "nn.functional.conv2d",
1091*da0073e9SAndroid Build Coastguard Worker            "nn.functional.conv3d",
1092*da0073e9SAndroid Build Coastguard Worker        ]
1093*da0073e9SAndroid Build Coastguard Worker        batched_input_size = dict(zip(convolutions, [3, 4, 5]))
1094*da0073e9SAndroid Build Coastguard Worker        if op.name == "nn.functional.linear":
1095*da0073e9SAndroid Build Coastguard Worker            is_supported_input = (
1096*da0073e9SAndroid Build Coastguard Worker                input.input.dim() > 1
1097*da0073e9SAndroid Build Coastguard Worker            )  # input of rank 1 means no batch dim
1098*da0073e9SAndroid Build Coastguard Worker        elif op.name == "nn.functional.layer_norm":
1099*da0073e9SAndroid Build Coastguard Worker            normalized_shape = input.args[0]
1100*da0073e9SAndroid Build Coastguard Worker            is_supported_input = (
1101*da0073e9SAndroid Build Coastguard Worker                input.input.shape != normalized_shape
1102*da0073e9SAndroid Build Coastguard Worker            )  # normalizing over the full input would mix the batch dim (inter-batch operation)
1103*da0073e9SAndroid Build Coastguard Worker        elif op.name in convolutions:
1104*da0073e9SAndroid Build Coastguard Worker            # currently can't handle the padding computation at the Python level
1105*da0073e9SAndroid Build Coastguard Worker            is_supported_input = input.input.dim() == batched_input_size[op.name]
1106*da0073e9SAndroid Build Coastguard Worker        elif op.name == "nn.functional.embedding":
1107*da0073e9SAndroid Build Coastguard Worker            idx = input.args[0]
1108*da0073e9SAndroid Build Coastguard Worker            is_supported_input = len(idx.shape) > 1  # rank-1 indices have no batch dim
1109*da0073e9SAndroid Build Coastguard Worker        else:
1110*da0073e9SAndroid Build Coastguard Worker            is_supported_input = True
1111*da0073e9SAndroid Build Coastguard Worker        is_supported_input = (
1112*da0073e9SAndroid Build Coastguard Worker            is_supported_input and input.input.shape[0] > 0
1113*da0073e9SAndroid Build Coastguard Worker        )  # 0 is not a valid batch size
1114*da0073e9SAndroid Build Coastguard Worker        return is_supported_input if supported_inputs else not is_supported_input
1115*da0073e9SAndroid Build Coastguard Worker
1116*da0073e9SAndroid Build Coastguard Worker    return [input for input in sample_inputs if filter_fn(input)]
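
# Usage sketch (illustrative): keep only the SampleInputs that ExpandedWeights can
# handle for a given op, or pass supported_inputs=False for the cases expected to fail:
#   ok = supported_inputs(op, op.sample_inputs(device, dtype, requires_grad=True))
#   unsupported = supported_inputs(
#       op, op.sample_inputs(device, dtype, requires_grad=True), supported_inputs=False
#   )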
1117*da0073e9SAndroid Build Coastguard Worker
1118*da0073e9SAndroid Build Coastguard Worker
1119*da0073e9SAndroid Build Coastguard Workerdef for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
1120*da0073e9SAndroid Build Coastguard Worker    # get per sample grads by computing the derivative for each sample in a for loop
1121*da0073e9SAndroid Build Coastguard Worker    per_sample_grad = []
1122*da0073e9SAndroid Build Coastguard Worker    for i in range(batch_size):
1123*da0073e9SAndroid Build Coastguard Worker        per_sample_input = input[i]
1124*da0073e9SAndroid Build Coastguard Worker        result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs))
1125*da0073e9SAndroid Build Coastguard Worker        diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
1126*da0073e9SAndroid Build Coastguard Worker        diff_input_list = [
1127*da0073e9SAndroid Build Coastguard Worker            t
1128*da0073e9SAndroid Build Coastguard Worker            for t in diff_input_list
1129*da0073e9SAndroid Build Coastguard Worker            if isinstance(t, torch.Tensor) and t.requires_grad
1130*da0073e9SAndroid Build Coastguard Worker        ]
1131*da0073e9SAndroid Build Coastguard Worker        per_sample_grad.append(
1132*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(
1133*da0073e9SAndroid Build Coastguard Worker                result, diff_input_list, torch.ones_like(result), allow_unused=True
1134*da0073e9SAndroid Build Coastguard Worker            )
1135*da0073e9SAndroid Build Coastguard Worker        )
1136*da0073e9SAndroid Build Coastguard Worker    if len(per_sample_grad) == batch_size:
1137*da0073e9SAndroid Build Coastguard Worker        per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad))
1138*da0073e9SAndroid Build Coastguard Worker    return per_sample_grad
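
# Usage sketch (hypothetical shapes): build a for-loop reference for a functional
# op; the result is a tuple of stacked per sample gradients, one entry per
# differentiable tensor, e.g.
#   weight = torch.randn(3, 4, requires_grad=True)
#   inp = torch.randn(8, 4)
#   grads = for_loop_per_sample_grad(8, torch.sum, inp, F.linear, weight)
#   # here grads[0].shape == (8, 3, 4), the per sample gradient of weight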
1139*da0073e9SAndroid Build Coastguard Worker
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Workerdef is_diff_tensor(t):
1142*da0073e9SAndroid Build Coastguard Worker    return isinstance(t, ExpandedWeight) or (
1143*da0073e9SAndroid Build Coastguard Worker        isinstance(t, torch.Tensor) and t.requires_grad
1144*da0073e9SAndroid Build Coastguard Worker    )
1145*da0073e9SAndroid Build Coastguard Worker
1146*da0073e9SAndroid Build Coastguard Worker
1147*da0073e9SAndroid Build Coastguard Workerdef clone_if_tensor(t):
1148*da0073e9SAndroid Build Coastguard Worker    if isinstance(t, torch.Tensor):
1149*da0073e9SAndroid Build Coastguard Worker        res = torch.clone(t).detach()
1150*da0073e9SAndroid Build Coastguard Worker        res.requires_grad_(t.requires_grad)
1151*da0073e9SAndroid Build Coastguard Worker        return res
1152*da0073e9SAndroid Build Coastguard Worker    else:
1153*da0073e9SAndroid Build Coastguard Worker        return t
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Worker
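# A minimal end-to-end sketch of the API exercised above (illustrative only;
# nothing in this file calls it), assuming a plain batch-first float input:
def _example_call_for_per_sample_grads():
    module = nn.Linear(4, 3)
    inp = torch.randn(8, 4)
    loss = call_for_per_sample_grads(module, loss_reduction="sum")(inp).sum()
    loss.backward()
    # each parameter now carries a per sample gradient of shape (8, *param.shape)
    return {name: p.grad_sample.shape for name, p in module.named_parameters()}

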
1156*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
1157*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestExpandedWeightFunctional, globals())
1158*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestExpandedWeightModule, globals())
1159*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
1160*da0073e9SAndroid Build Coastguard Worker    run_tests()