1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: nn"] 2*da0073e9SAndroid Build Coastguard Workerimport unittest 3*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass 4*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 5*da0073e9SAndroid Build Coastguard Workerfrom itertools import chain, product 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 9*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 10*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import CrossEntropyLoss 11*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.utils._expanded_weights import ExpandedWeight 12*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.utils._expanded_weights.expanded_weights_utils import ( 13*da0073e9SAndroid Build Coastguard Worker forward_helper, 14*da0073e9SAndroid Build Coastguard Worker set_grad_sample_if_exists, 15*da0073e9SAndroid Build Coastguard Worker standard_kwargs, 16*da0073e9SAndroid Build Coastguard Worker sum_over_all_but_batch_and_last_n, 17*da0073e9SAndroid Build Coastguard Worker unpack_expanded_weight_or_tensor, 18*da0073e9SAndroid Build Coastguard Worker) 19*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.utils._per_sample_grad import call_for_per_sample_grads 20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA, tf32_off 21*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 22*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 23*da0073e9SAndroid Build Coastguard Worker OpDTypes, 24*da0073e9SAndroid Build Coastguard Worker ops, 25*da0073e9SAndroid Build Coastguard Worker) 26*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import op_db, SampleInput 27*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_modules import module_db, modules 28*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase 29*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 30*da0073e9SAndroid Build Coastguard Worker freeze_rng_state, 31*da0073e9SAndroid Build Coastguard Worker make_tensor, 32*da0073e9SAndroid Build Coastguard Worker parametrize, 33*da0073e9SAndroid Build Coastguard Worker run_tests, 34*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 35*da0073e9SAndroid Build Coastguard Worker TestCase, 36*da0073e9SAndroid Build Coastguard Worker) 37*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map_only 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Workerclass TestContext: 41*da0073e9SAndroid Build Coastguard Worker pass 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Workerclass TestExpandedWeightHelperFunction(TestCase): 45*da0073e9SAndroid Build Coastguard Worker def test_forward_helper(self, device): 46*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4, device=device) 47*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(5, 4, device=device) 48*da0073e9SAndroid Build Coastguard Worker bias = torch.randn(5, device=device) 49*da0073e9SAndroid Build Coastguard Worker for weight_batched, bias_batched in product([True, False], [True, False]): 50*da0073e9SAndroid Build Coastguard Worker maybe_batched_weight = weight 51*da0073e9SAndroid Build Coastguard Worker maybe_batched_bias = bias 52*da0073e9SAndroid Build Coastguard Worker if weight_batched: 53*da0073e9SAndroid Build Coastguard Worker maybe_batched_weight = ExpandedWeight( 54*da0073e9SAndroid Build Coastguard Worker weight.clone().requires_grad_(), 3, loss_reduction="sum" 55*da0073e9SAndroid Build Coastguard Worker ) 56*da0073e9SAndroid Build Coastguard Worker if bias_batched: 57*da0073e9SAndroid Build Coastguard Worker maybe_batched_bias = ExpandedWeight( 58*da0073e9SAndroid Build Coastguard Worker bias.clone().requires_grad_(), 3, loss_reduction="sum" 59*da0073e9SAndroid Build Coastguard Worker ) 60*da0073e9SAndroid Build Coastguard Worker args = (input, maybe_batched_weight, maybe_batched_bias) 61*da0073e9SAndroid Build Coastguard Worker expanded_args, expanded_kwargs = standard_kwargs(("bias",), args) 62*da0073e9SAndroid Build Coastguard Worker res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) 63*da0073e9SAndroid Build Coastguard Worker expected = nn.functional.linear(input, weight, bias) 64*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(expanded_args), 2) 67*da0073e9SAndroid Build Coastguard Worker assert expanded_args[0] is args[0] # avoids property checks in assertEquals 68*da0073e9SAndroid Build Coastguard Worker assert expanded_args[1] is args[1] # avoids property checks in assertEquals 69*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(expanded_kwargs), 1) 70*da0073e9SAndroid Build Coastguard Worker assert ( 71*da0073e9SAndroid Build Coastguard Worker expanded_kwargs["bias"] is args[2] 72*da0073e9SAndroid Build Coastguard Worker ) # avoids property checks in assertEquals 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker def test_forward_helper_failure_args(self, device): 75*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(5, 4, device=device) 76*da0073e9SAndroid Build Coastguard Worker bias = torch.randn(5, device=device) 77*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 78*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"do not support inputs that are also ExpandedWeights." 79*da0073e9SAndroid Build Coastguard Worker ): 80*da0073e9SAndroid Build Coastguard Worker input = ExpandedWeight( 81*da0073e9SAndroid Build Coastguard Worker torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum" 82*da0073e9SAndroid Build Coastguard Worker ) 83*da0073e9SAndroid Build Coastguard Worker expanded_args, expanded_kwargs = standard_kwargs( 84*da0073e9SAndroid Build Coastguard Worker ("bias",), (input, weight, bias) 85*da0073e9SAndroid Build Coastguard Worker ) 86*da0073e9SAndroid Build Coastguard Worker forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) 87*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 88*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"requires a Tensor as the first input" 89*da0073e9SAndroid Build Coastguard Worker ): 90*da0073e9SAndroid Build Coastguard Worker expanded_args, expanded_kwargs = standard_kwargs( 91*da0073e9SAndroid Build Coastguard Worker ("bias",), (3, weight, bias) 92*da0073e9SAndroid Build Coastguard Worker ) 93*da0073e9SAndroid Build Coastguard Worker forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) 94*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 95*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"requires a batch dimension but got an input of size 0" 96*da0073e9SAndroid Build Coastguard Worker ): 97*da0073e9SAndroid Build Coastguard Worker expanded_args, expanded_kwargs = standard_kwargs( 98*da0073e9SAndroid Build Coastguard Worker ("bias",), (torch.tensor(3), weight, bias) 99*da0073e9SAndroid Build Coastguard Worker ) 100*da0073e9SAndroid Build Coastguard Worker forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) 101*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 102*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"0 is not a valid batch size for Expanded Weights" 103*da0073e9SAndroid Build Coastguard Worker ): 104*da0073e9SAndroid Build Coastguard Worker expanded_args, expanded_kwargs = standard_kwargs( 105*da0073e9SAndroid Build Coastguard Worker ("bias",), (torch.randn(0, 1, 2), weight, bias) 106*da0073e9SAndroid Build Coastguard Worker ) 107*da0073e9SAndroid Build Coastguard Worker forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) 108*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 109*da0073e9SAndroid Build Coastguard Worker for weight_batched, bias_batched in product([True, False], [True, False]): 110*da0073e9SAndroid Build Coastguard Worker if not weight_batched and not bias_batched: 111*da0073e9SAndroid Build Coastguard Worker continue 112*da0073e9SAndroid Build Coastguard Worker maybe_batched_weight = weight 113*da0073e9SAndroid Build Coastguard Worker maybe_batched_bias = bias 114*da0073e9SAndroid Build Coastguard Worker if weight_batched: 115*da0073e9SAndroid Build Coastguard Worker maybe_batched_weight = ExpandedWeight( 116*da0073e9SAndroid Build Coastguard Worker weight.clone().requires_grad_(), 4, loss_reduction="sum" 117*da0073e9SAndroid Build Coastguard Worker ) 118*da0073e9SAndroid Build Coastguard Worker if bias_batched: 119*da0073e9SAndroid Build Coastguard Worker maybe_batched_bias = ExpandedWeight( 120*da0073e9SAndroid Build Coastguard Worker bias.clone().requires_grad_(), 4, loss_reduction="sum" 121*da0073e9SAndroid Build Coastguard Worker ) 122*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 123*da0073e9SAndroid Build Coastguard Worker RuntimeError, 124*da0073e9SAndroid Build Coastguard Worker r"Expected ExpandedWeights to have batch size matching input", 125*da0073e9SAndroid Build Coastguard Worker ): 126*da0073e9SAndroid Build Coastguard Worker expanded_args, expanded_kwargs = standard_kwargs( 127*da0073e9SAndroid Build Coastguard Worker ("bias",), (input, maybe_batched_weight, maybe_batched_bias) 128*da0073e9SAndroid Build Coastguard Worker ) 129*da0073e9SAndroid Build Coastguard Worker forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker def test_set_grad_sample_if_exists(self, device): 132*da0073e9SAndroid Build Coastguard Worker def test_fn(a): 133*da0073e9SAndroid Build Coastguard Worker return grad_sample 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker orig_weight = torch.randn(4, device=device, requires_grad=True) 136*da0073e9SAndroid Build Coastguard Worker expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum") 137*da0073e9SAndroid Build Coastguard Worker grad_sample = torch.randn(3) 138*da0073e9SAndroid Build Coastguard Worker set_grad_sample_if_exists(expanded_weight, test_fn) 139*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(orig_weight, "grad_sample")) 140*da0073e9SAndroid Build Coastguard Worker self.assertEqual(orig_weight.grad_sample, grad_sample) 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker basic_tensor = torch.randn(4, device=device) 143*da0073e9SAndroid Build Coastguard Worker set_grad_sample_if_exists(basic_tensor, test_fn) 144*da0073e9SAndroid Build Coastguard Worker self.assertFalse(hasattr(basic_tensor, "grad_sample")) 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker non_tensor = 3 147*da0073e9SAndroid Build Coastguard Worker set_grad_sample_if_exists(non_tensor, test_fn) 148*da0073e9SAndroid Build Coastguard Worker self.assertFalse(hasattr(non_tensor, "grad_sample")) 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker def test_set_grad_sample_if_exists_failure(self, device): 151*da0073e9SAndroid Build Coastguard Worker def test_fn(a): 152*da0073e9SAndroid Build Coastguard Worker return True 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker grad_tensor = torch.randn(4, requires_grad=True, device=device) 155*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 156*da0073e9SAndroid Build Coastguard Worker RuntimeError, 157*da0073e9SAndroid Build Coastguard Worker r"does not support a mixture of ExpandedWeight parameters and normal Parameters", 158*da0073e9SAndroid Build Coastguard Worker ): 159*da0073e9SAndroid Build Coastguard Worker set_grad_sample_if_exists(grad_tensor, test_fn) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker def test_unpack_expanded_weight_or_tensor(self, device): 162*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, requires_grad=True, device=device) 163*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 164*da0073e9SAndroid Build Coastguard Worker input, 165*da0073e9SAndroid Build Coastguard Worker unpack_expanded_weight_or_tensor( 166*da0073e9SAndroid Build Coastguard Worker ExpandedWeight(input, 3, loss_reduction="sum") 167*da0073e9SAndroid Build Coastguard Worker ), 168*da0073e9SAndroid Build Coastguard Worker ) 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker input.requires_grad_(False) 171*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, unpack_expanded_weight_or_tensor(input)) 172*da0073e9SAndroid Build Coastguard Worker self.assertTrue(unpack_expanded_weight_or_tensor(4) is None) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device): 175*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, requires_grad=True, device=device) 176*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 177*da0073e9SAndroid Build Coastguard Worker unpack_expanded_weight_or_tensor( 178*da0073e9SAndroid Build Coastguard Worker ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input 179*da0073e9SAndroid Build Coastguard Worker ) 180*da0073e9SAndroid Build Coastguard Worker ) 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker input.requires_grad_(False) 183*da0073e9SAndroid Build Coastguard Worker self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input)) 184*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 185*da0073e9SAndroid Build Coastguard Worker unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None 186*da0073e9SAndroid Build Coastguard Worker ) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker def test_unpack_expanded_weight_or_tensor_failure(self, device): 189*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, requires_grad=True, device=device) 190*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 191*da0073e9SAndroid Build Coastguard Worker RuntimeError, 192*da0073e9SAndroid Build Coastguard Worker r"does not support a mixture of ExpandedWeight parameters and normal Parameters", 193*da0073e9SAndroid Build Coastguard Worker ): 194*da0073e9SAndroid Build Coastguard Worker unpack_expanded_weight_or_tensor(input) 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 197*da0073e9SAndroid Build Coastguard Worker RuntimeError, 198*da0073e9SAndroid Build Coastguard Worker r"does not support a mixture of ExpandedWeight parameters and normal Parameters", 199*da0073e9SAndroid Build Coastguard Worker ): 200*da0073e9SAndroid Build Coastguard Worker unpack_expanded_weight_or_tensor(input, lambda x: x is input) 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker def test_sum_over_all_but_batch_and_last_n(self, device): 203*da0073e9SAndroid Build Coastguard Worker input = torch.randn(1, 2, 3, 4, 5, device=device) 204*da0073e9SAndroid Build Coastguard Worker res = sum_over_all_but_batch_and_last_n(input, 2) 205*da0073e9SAndroid Build Coastguard Worker expected = input.sum((1, 2)) 206*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker res = sum_over_all_but_batch_and_last_n(input, 0) 209*da0073e9SAndroid Build Coastguard Worker expected = input.sum((1, 2, 3, 4)) 210*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker res = sum_over_all_but_batch_and_last_n(input, 4) 213*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, input) 214*da0073e9SAndroid Build Coastguard Worker 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Workerclass TestExpandedWeightFunctional(TestCase): 217*da0073e9SAndroid Build Coastguard Worker def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction): 218*da0073e9SAndroid Build Coastguard Worker input = sample_input.input 219*da0073e9SAndroid Build Coastguard Worker args = sample_input.args 220*da0073e9SAndroid Build Coastguard Worker kwargs = sample_input.kwargs 221*da0073e9SAndroid Build Coastguard Worker batch_size = input.shape[0] if len(input.shape) > 1 else 1 222*da0073e9SAndroid Build Coastguard Worker 223*da0073e9SAndroid Build Coastguard Worker # get per sample grads with ExpandedWeights objects 224*da0073e9SAndroid Build Coastguard Worker loss_reduction = "sum" if reduction == torch.sum else "mean" 225*da0073e9SAndroid Build Coastguard Worker (ew_input, ew_args, ew_kwargs) = make_expanded_weight( 226*da0073e9SAndroid Build Coastguard Worker sample_input, batch_size, loss_reduction 227*da0073e9SAndroid Build Coastguard Worker ) 228*da0073e9SAndroid Build Coastguard Worker diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values()) 229*da0073e9SAndroid Build Coastguard Worker diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)] 230*da0073e9SAndroid Build Coastguard Worker diff_input_list = [ 231*da0073e9SAndroid Build Coastguard Worker i.orig_weight if isinstance(i, ExpandedWeight) else i 232*da0073e9SAndroid Build Coastguard Worker for i in diff_input_list 233*da0073e9SAndroid Build Coastguard Worker ] 234*da0073e9SAndroid Build Coastguard Worker if not diff_input_list: 235*da0073e9SAndroid Build Coastguard Worker return 236*da0073e9SAndroid Build Coastguard Worker result = run_op(op, ew_input, *ew_args, **ew_kwargs) 237*da0073e9SAndroid Build Coastguard Worker reduction( 238*da0073e9SAndroid Build Coastguard Worker result 239*da0073e9SAndroid Build Coastguard Worker ).backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__ 240*da0073e9SAndroid Build Coastguard Worker expanded_weight_grad = tuple( 241*da0073e9SAndroid Build Coastguard Worker i.grad_sample if hasattr(i, "grad_sample") else i.grad 242*da0073e9SAndroid Build Coastguard Worker for i in diff_input_list 243*da0073e9SAndroid Build Coastguard Worker ) 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker # get per sample grads with for loop 246*da0073e9SAndroid Build Coastguard Worker func = partial(run_op, op) 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker per_sample_grad = for_loop_per_sample_grad( 249*da0073e9SAndroid Build Coastguard Worker batch_size, reduction, input, func, *args, **kwargs 250*da0073e9SAndroid Build Coastguard Worker ) 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker # check equality 253*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(per_sample_grad), len(expanded_weight_grad)) 254*da0073e9SAndroid Build Coastguard Worker if loss_reduction == "mean": 255*da0073e9SAndroid Build Coastguard Worker # don't check equality of `input.grad`s since these vanilla tensors won't be scaled 256*da0073e9SAndroid Build Coastguard Worker expanded_weight_grad = expanded_weight_grad[1:] 257*da0073e9SAndroid Build Coastguard Worker per_sample_grad = per_sample_grad[1:] 258*da0073e9SAndroid Build Coastguard Worker for result_grad, expected_grad in zip(expanded_weight_grad, per_sample_grad): 259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_grad, expected_grad) 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker @ops( 262*da0073e9SAndroid Build Coastguard Worker filter(lambda op: op.supports_expanded_weight, op_db), 263*da0073e9SAndroid Build Coastguard Worker dtypes=OpDTypes.supported, 264*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=(torch.double,), 265*da0073e9SAndroid Build Coastguard Worker ) 266*da0073e9SAndroid Build Coastguard Worker def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op): 267*da0073e9SAndroid Build Coastguard Worker sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) 268*da0073e9SAndroid Build Coastguard Worker for sample_input in supported_inputs(op, sample_inputs): 269*da0073e9SAndroid Build Coastguard Worker if ( 270*da0073e9SAndroid Build Coastguard Worker op.name == "nn.functional.embedding" 271*da0073e9SAndroid Build Coastguard Worker ): # embedding flips its argument order for autograd tests 272*da0073e9SAndroid Build Coastguard Worker sample_input = SampleInput( 273*da0073e9SAndroid Build Coastguard Worker sample_input.args[0], 274*da0073e9SAndroid Build Coastguard Worker args=(sample_input.input,), 275*da0073e9SAndroid Build Coastguard Worker kwargs=sample_input.kwargs, 276*da0073e9SAndroid Build Coastguard Worker ) 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum) 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker @ops( 281*da0073e9SAndroid Build Coastguard Worker filter(lambda op: op.supports_expanded_weight, op_db), 282*da0073e9SAndroid Build Coastguard Worker dtypes=OpDTypes.supported, 283*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=(torch.double,), 284*da0073e9SAndroid Build Coastguard Worker ) 285*da0073e9SAndroid Build Coastguard Worker def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op): 286*da0073e9SAndroid Build Coastguard Worker sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) 287*da0073e9SAndroid Build Coastguard Worker for sample_input in supported_inputs(op, sample_inputs): 288*da0073e9SAndroid Build Coastguard Worker if ( 289*da0073e9SAndroid Build Coastguard Worker op.name == "nn.functional.embedding" 290*da0073e9SAndroid Build Coastguard Worker ): # embedding flips its argument order for autograd tests 291*da0073e9SAndroid Build Coastguard Worker sample_input = SampleInput( 292*da0073e9SAndroid Build Coastguard Worker sample_input.args[0], 293*da0073e9SAndroid Build Coastguard Worker args=(sample_input.input,), 294*da0073e9SAndroid Build Coastguard Worker kwargs=sample_input.kwargs, 295*da0073e9SAndroid Build Coastguard Worker ) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker @ops( 300*da0073e9SAndroid Build Coastguard Worker filter(lambda op: op.supports_expanded_weight, op_db), 301*da0073e9SAndroid Build Coastguard Worker dtypes=OpDTypes.supported, 302*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=(torch.double,), 303*da0073e9SAndroid Build Coastguard Worker ) 304*da0073e9SAndroid Build Coastguard Worker def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op): 305*da0073e9SAndroid Build Coastguard Worker sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) 306*da0073e9SAndroid Build Coastguard Worker for sample_input in supported_inputs(op, sample_inputs): 307*da0073e9SAndroid Build Coastguard Worker if ( 308*da0073e9SAndroid Build Coastguard Worker op.name == "nn.functional.embedding" 309*da0073e9SAndroid Build Coastguard Worker ): # embedding flips its argument order for autograd tests 310*da0073e9SAndroid Build Coastguard Worker sample_input = SampleInput( 311*da0073e9SAndroid Build Coastguard Worker sample_input.args[0], 312*da0073e9SAndroid Build Coastguard Worker args=(sample_input.input,), 313*da0073e9SAndroid Build Coastguard Worker kwargs=sample_input.kwargs, 314*da0073e9SAndroid Build Coastguard Worker ) 315*da0073e9SAndroid Build Coastguard Worker sample_input.input.requires_grad_(False) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean) 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Checking error message doesn't work with dynamo") 320*da0073e9SAndroid Build Coastguard Worker @ops( 321*da0073e9SAndroid Build Coastguard Worker filter(lambda op: op.supports_expanded_weight, op_db), 322*da0073e9SAndroid Build Coastguard Worker dtypes=OpDTypes.supported, 323*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=(torch.double,), 324*da0073e9SAndroid Build Coastguard Worker ) 325*da0073e9SAndroid Build Coastguard Worker def test_unsupported_expand_weights(self, device, dtype, op): 326*da0073e9SAndroid Build Coastguard Worker sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) 327*da0073e9SAndroid Build Coastguard Worker unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False) 328*da0073e9SAndroid Build Coastguard Worker for sample_input in unsupported_inputs: 329*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"): 330*da0073e9SAndroid Build Coastguard Worker if ( 331*da0073e9SAndroid Build Coastguard Worker op.name == "nn.functional.embedding" 332*da0073e9SAndroid Build Coastguard Worker ): # embedding flips its argument order for autograd tests 333*da0073e9SAndroid Build Coastguard Worker sample_input = SampleInput( 334*da0073e9SAndroid Build Coastguard Worker sample_input.args[0], 335*da0073e9SAndroid Build Coastguard Worker args=(sample_input.input,), 336*da0073e9SAndroid Build Coastguard Worker kwargs=sample_input.kwargs, 337*da0073e9SAndroid Build Coastguard Worker ) 338*da0073e9SAndroid Build Coastguard Worker input = sample_input.input 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker batch_size = input.shape[0] if len(input.shape) > 1 else 1 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker # get per sample grads with ExpandedWeights objects 343*da0073e9SAndroid Build Coastguard Worker (ew_input, ew_args, ew_kwargs) = make_expanded_weight( 344*da0073e9SAndroid Build Coastguard Worker sample_input, batch_size 345*da0073e9SAndroid Build Coastguard Worker ) 346*da0073e9SAndroid Build Coastguard Worker result = run_op(op, ew_input, *ew_args, **ew_kwargs) 347*da0073e9SAndroid Build Coastguard Worker diff_input_list = ( 348*da0073e9SAndroid Build Coastguard Worker (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values()) 349*da0073e9SAndroid Build Coastguard Worker ) 350*da0073e9SAndroid Build Coastguard Worker diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)] 351*da0073e9SAndroid Build Coastguard Worker diff_input_list = [ 352*da0073e9SAndroid Build Coastguard Worker i.orig_weight if isinstance(i, ExpandedWeight) else i 353*da0073e9SAndroid Build Coastguard Worker for i in diff_input_list 354*da0073e9SAndroid Build Coastguard Worker ] 355*da0073e9SAndroid Build Coastguard Worker result.sum().backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__ 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker @ops( 358*da0073e9SAndroid Build Coastguard Worker filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported 359*da0073e9SAndroid Build Coastguard Worker ) 360*da0073e9SAndroid Build Coastguard Worker def test_expanded_weight_forward(self, device, dtype, op): 361*da0073e9SAndroid Build Coastguard Worker sample_inputs = op.sample_inputs(device, dtype) 362*da0073e9SAndroid Build Coastguard Worker for sample_input in supported_inputs(op, sample_inputs): 363*da0073e9SAndroid Build Coastguard Worker if ( 364*da0073e9SAndroid Build Coastguard Worker op.name == "nn.functional.embedding" 365*da0073e9SAndroid Build Coastguard Worker ): # embedding flips its argument order for autograd tests 366*da0073e9SAndroid Build Coastguard Worker sample_input = SampleInput( 367*da0073e9SAndroid Build Coastguard Worker sample_input.args[0].clone(), 368*da0073e9SAndroid Build Coastguard Worker args=(sample_input.input.clone(),), 369*da0073e9SAndroid Build Coastguard Worker kwargs=sample_input.kwargs, 370*da0073e9SAndroid Build Coastguard Worker ) 371*da0073e9SAndroid Build Coastguard Worker if ( 372*da0073e9SAndroid Build Coastguard Worker "cuda" in device 373*da0073e9SAndroid Build Coastguard Worker and "max_norm" in sample_input.kwargs 374*da0073e9SAndroid Build Coastguard Worker and "padding_idx" in sample_input.kwargs 375*da0073e9SAndroid Build Coastguard Worker ): 376*da0073e9SAndroid Build Coastguard Worker self.skipTest( 377*da0073e9SAndroid Build Coastguard Worker "embedding is non-determinstic in this case, see issue #74679" 378*da0073e9SAndroid Build Coastguard Worker ) 379*da0073e9SAndroid Build Coastguard Worker batch_size = ( 380*da0073e9SAndroid Build Coastguard Worker sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1 381*da0073e9SAndroid Build Coastguard Worker ) 382*da0073e9SAndroid Build Coastguard Worker for loss_reduction in ["sum", "mean"]: 383*da0073e9SAndroid Build Coastguard Worker (ew_input, ew_args, ew_kwargs) = make_expanded_weight( 384*da0073e9SAndroid Build Coastguard Worker sample_input, batch_size, loss_reduction 385*da0073e9SAndroid Build Coastguard Worker ) 386*da0073e9SAndroid Build Coastguard Worker expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs) 387*da0073e9SAndroid Build Coastguard Worker normal_result = run_op( 388*da0073e9SAndroid Build Coastguard Worker op, sample_input.input, *sample_input.args, **sample_input.kwargs 389*da0073e9SAndroid Build Coastguard Worker ) 390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expanded_weight_result, normal_result) 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Worker def test_expanded_weight_error(self, device): 393*da0073e9SAndroid Build Coastguard Worker batch_size = 3 394*da0073e9SAndroid Build Coastguard Worker sample_input = make_tensor( 395*da0073e9SAndroid Build Coastguard Worker (batch_size, 4), dtype=torch.float32, device=device, requires_grad=True 396*da0073e9SAndroid Build Coastguard Worker ) 397*da0073e9SAndroid Build Coastguard Worker sample_weight = make_tensor( 398*da0073e9SAndroid Build Coastguard Worker (4), dtype=torch.float32, device=device, requires_grad=True 399*da0073e9SAndroid Build Coastguard Worker ) 400*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 401*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Expanded Weights encountered but cannot handle function" 402*da0073e9SAndroid Build Coastguard Worker ): 403*da0073e9SAndroid Build Coastguard Worker torch.add( 404*da0073e9SAndroid Build Coastguard Worker sample_input, 405*da0073e9SAndroid Build Coastguard Worker ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"), 406*da0073e9SAndroid Build Coastguard Worker ) 407*da0073e9SAndroid Build Coastguard Worker 408*da0073e9SAndroid Build Coastguard Worker def _test_embedding_model(self, model, num_embedding, device): 409*da0073e9SAndroid Build Coastguard Worker batch_size = 32 410*da0073e9SAndroid Build Coastguard Worker input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device) 411*da0073e9SAndroid Build Coastguard Worker return self._test_model( 412*da0073e9SAndroid Build Coastguard Worker partial(model, num_embedding=num_embedding), batch_size, input, device 413*da0073e9SAndroid Build Coastguard Worker ) 414*da0073e9SAndroid Build Coastguard Worker 415*da0073e9SAndroid Build Coastguard Worker def _test_conv_model( 416*da0073e9SAndroid Build Coastguard Worker self, 417*da0073e9SAndroid Build Coastguard Worker model, 418*da0073e9SAndroid Build Coastguard Worker input_size, 419*da0073e9SAndroid Build Coastguard Worker num_dim, 420*da0073e9SAndroid Build Coastguard Worker device, 421*da0073e9SAndroid Build Coastguard Worker loss_reduction="sum", 422*da0073e9SAndroid Build Coastguard Worker atol=1e-4, 423*da0073e9SAndroid Build Coastguard Worker rtol=5e-5, 424*da0073e9SAndroid Build Coastguard Worker ): 425*da0073e9SAndroid Build Coastguard Worker batch_size = 32 426*da0073e9SAndroid Build Coastguard Worker input_ending = [input_size] * num_dim 427*da0073e9SAndroid Build Coastguard Worker input = torch.randn([batch_size, 3] + input_ending, device=device) 428*da0073e9SAndroid Build Coastguard Worker return self._test_model( 429*da0073e9SAndroid Build Coastguard Worker partial(model, num_dim=num_dim), 430*da0073e9SAndroid Build Coastguard Worker batch_size, 431*da0073e9SAndroid Build Coastguard Worker input, 432*da0073e9SAndroid Build Coastguard Worker device, 433*da0073e9SAndroid Build Coastguard Worker loss_reduction, 434*da0073e9SAndroid Build Coastguard Worker atol, 435*da0073e9SAndroid Build Coastguard Worker rtol, 436*da0073e9SAndroid Build Coastguard Worker ) 437*da0073e9SAndroid Build Coastguard Worker 438*da0073e9SAndroid Build Coastguard Worker def _test_model( 439*da0073e9SAndroid Build Coastguard Worker self, 440*da0073e9SAndroid Build Coastguard Worker model, 441*da0073e9SAndroid Build Coastguard Worker batch_size, 442*da0073e9SAndroid Build Coastguard Worker input, 443*da0073e9SAndroid Build Coastguard Worker device, 444*da0073e9SAndroid Build Coastguard Worker loss_reduction="sum", 445*da0073e9SAndroid Build Coastguard Worker atol=1e-4, 446*da0073e9SAndroid Build Coastguard Worker rtol=5e-5, 447*da0073e9SAndroid Build Coastguard Worker ): 448*da0073e9SAndroid Build Coastguard Worker model = model(10).to(device) 449*da0073e9SAndroid Build Coastguard Worker targets = torch.randint(0, 10, (batch_size,), device=device) 450*da0073e9SAndroid Build Coastguard Worker criterion = CrossEntropyLoss(reduction=loss_reduction) 451*da0073e9SAndroid Build Coastguard Worker result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input) 452*da0073e9SAndroid Build Coastguard Worker loss = criterion(result, targets) 453*da0073e9SAndroid Build Coastguard Worker loss.backward() 454*da0073e9SAndroid Build Coastguard Worker result = [] 455*da0073e9SAndroid Build Coastguard Worker for weight in model.parameters(): 456*da0073e9SAndroid Build Coastguard Worker result.append(weight.grad_sample) 457*da0073e9SAndroid Build Coastguard Worker del weight.grad_sample 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker expected = [] 460*da0073e9SAndroid Build Coastguard Worker for i in range(batch_size): 461*da0073e9SAndroid Build Coastguard Worker loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0)) 462*da0073e9SAndroid Build Coastguard Worker expected.append( 463*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss)) 464*da0073e9SAndroid Build Coastguard Worker ) 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Worker expected = [torch.stack(grad) for grad in zip(*expected)] 467*da0073e9SAndroid Build Coastguard Worker for res, exp in zip(result, expected): 468*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, exp, atol=atol, rtol=rtol) 469*da0073e9SAndroid Build Coastguard Worker 470*da0073e9SAndroid Build Coastguard Worker def _compute_tolerances(self, device): 471*da0073e9SAndroid Build Coastguard Worker is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability( 472*da0073e9SAndroid Build Coastguard Worker 0 473*da0073e9SAndroid Build Coastguard Worker ) == (8, 6) 474*da0073e9SAndroid Build Coastguard Worker return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5) 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker @tf32_off() 477*da0073e9SAndroid Build Coastguard Worker def test_cnn_model_sum(self, device): 478*da0073e9SAndroid Build Coastguard Worker def convnet(num_classes, num_dim): 479*da0073e9SAndroid Build Coastguard Worker return nn.Sequential( 480*da0073e9SAndroid Build Coastguard Worker nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), 481*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 482*da0073e9SAndroid Build Coastguard Worker nn.AvgPool2d(kernel_size=2, stride=2), 483*da0073e9SAndroid Build Coastguard Worker nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), 484*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 485*da0073e9SAndroid Build Coastguard Worker nn.AvgPool2d(kernel_size=2, stride=2), 486*da0073e9SAndroid Build Coastguard Worker nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 487*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 488*da0073e9SAndroid Build Coastguard Worker nn.AvgPool2d(kernel_size=2, stride=2), 489*da0073e9SAndroid Build Coastguard Worker nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), 490*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 491*da0073e9SAndroid Build Coastguard Worker nn.AdaptiveAvgPool2d((1, 1)), 492*da0073e9SAndroid Build Coastguard Worker nn.Flatten(start_dim=1, end_dim=-1), 493*da0073e9SAndroid Build Coastguard Worker nn.Linear(128, num_classes, bias=True), 494*da0073e9SAndroid Build Coastguard Worker ) 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker atol, rtol = self._compute_tolerances(device) 497*da0073e9SAndroid Build Coastguard Worker return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol) 498*da0073e9SAndroid Build Coastguard Worker 499*da0073e9SAndroid Build Coastguard Worker @tf32_off() 500*da0073e9SAndroid Build Coastguard Worker def test_cnn_model_mean(self, device): 501*da0073e9SAndroid Build Coastguard Worker def convnet(num_classes, num_dim): 502*da0073e9SAndroid Build Coastguard Worker return nn.Sequential( 503*da0073e9SAndroid Build Coastguard Worker nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), 504*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 505*da0073e9SAndroid Build Coastguard Worker nn.AvgPool2d(kernel_size=2, stride=2), 506*da0073e9SAndroid Build Coastguard Worker nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), 507*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 508*da0073e9SAndroid Build Coastguard Worker nn.AvgPool2d(kernel_size=2, stride=2), 509*da0073e9SAndroid Build Coastguard Worker nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 510*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 511*da0073e9SAndroid Build Coastguard Worker nn.AvgPool2d(kernel_size=2, stride=2), 512*da0073e9SAndroid Build Coastguard Worker nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), 513*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 514*da0073e9SAndroid Build Coastguard Worker nn.AdaptiveAvgPool2d((1, 1)), 515*da0073e9SAndroid Build Coastguard Worker nn.Flatten(start_dim=1, end_dim=-1), 516*da0073e9SAndroid Build Coastguard Worker nn.Linear(128, num_classes, bias=True), 517*da0073e9SAndroid Build Coastguard Worker ) 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker atol, rtol = self._compute_tolerances(device) 520*da0073e9SAndroid Build Coastguard Worker return self._test_conv_model( 521*da0073e9SAndroid Build Coastguard Worker convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol 522*da0073e9SAndroid Build Coastguard Worker ) 523*da0073e9SAndroid Build Coastguard Worker 524*da0073e9SAndroid Build Coastguard Worker @parametrize("num_dim", [1, 2, 3]) 525*da0073e9SAndroid Build Coastguard Worker @tf32_off() 526*da0073e9SAndroid Build Coastguard Worker def test_instance_norm_model(self, num_dim, device): 527*da0073e9SAndroid Build Coastguard Worker def instance_norm_model(num_classes, num_dim): 528*da0073e9SAndroid Build Coastguard Worker conv_layer = ( 529*da0073e9SAndroid Build Coastguard Worker nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d 530*da0073e9SAndroid Build Coastguard Worker ) 531*da0073e9SAndroid Build Coastguard Worker norm_layer = ( 532*da0073e9SAndroid Build Coastguard Worker nn.InstanceNorm1d 533*da0073e9SAndroid Build Coastguard Worker if num_dim == 1 534*da0073e9SAndroid Build Coastguard Worker else nn.InstanceNorm2d 535*da0073e9SAndroid Build Coastguard Worker if num_dim == 2 536*da0073e9SAndroid Build Coastguard Worker else nn.InstanceNorm3d 537*da0073e9SAndroid Build Coastguard Worker ) 538*da0073e9SAndroid Build Coastguard Worker return nn.Sequential( 539*da0073e9SAndroid Build Coastguard Worker conv_layer(3, 32, kernel_size=3, stride=1, padding=1), 540*da0073e9SAndroid Build Coastguard Worker norm_layer(32, affine=True), 541*da0073e9SAndroid Build Coastguard Worker nn.Flatten(start_dim=1, end_dim=-1), 542*da0073e9SAndroid Build Coastguard Worker nn.Linear(32 * (7**num_dim), num_classes, bias=True), 543*da0073e9SAndroid Build Coastguard Worker ) 544*da0073e9SAndroid Build Coastguard Worker 545*da0073e9SAndroid Build Coastguard Worker atol, rtol = self._compute_tolerances(device) 546*da0073e9SAndroid Build Coastguard Worker return self._test_conv_model( 547*da0073e9SAndroid Build Coastguard Worker instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol 548*da0073e9SAndroid Build Coastguard Worker ) 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker @parametrize("num_dim", [1, 2, 3]) 551*da0073e9SAndroid Build Coastguard Worker @tf32_off() 552*da0073e9SAndroid Build Coastguard Worker def test_group_norm_model(self, num_dim, device): 553*da0073e9SAndroid Build Coastguard Worker def group_norm_model(num_classes, num_dim): 554*da0073e9SAndroid Build Coastguard Worker conv_layer = ( 555*da0073e9SAndroid Build Coastguard Worker nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d 556*da0073e9SAndroid Build Coastguard Worker ) 557*da0073e9SAndroid Build Coastguard Worker return nn.Sequential( 558*da0073e9SAndroid Build Coastguard Worker conv_layer(3, 32, kernel_size=3, stride=1, padding=1), 559*da0073e9SAndroid Build Coastguard Worker nn.GroupNorm(8, 32, affine=True), 560*da0073e9SAndroid Build Coastguard Worker nn.Flatten(start_dim=1, end_dim=-1), 561*da0073e9SAndroid Build Coastguard Worker nn.Linear(32 * (7**num_dim), num_classes, bias=True), 562*da0073e9SAndroid Build Coastguard Worker ) 563*da0073e9SAndroid Build Coastguard Worker 564*da0073e9SAndroid Build Coastguard Worker atol, rtol = self._compute_tolerances(device) 565*da0073e9SAndroid Build Coastguard Worker return self._test_conv_model( 566*da0073e9SAndroid Build Coastguard Worker group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol 567*da0073e9SAndroid Build Coastguard Worker ) 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker @parametrize("num_dim", [1, 2, 3]) 570*da0073e9SAndroid Build Coastguard Worker @tf32_off() 571*da0073e9SAndroid Build Coastguard Worker def test_layer_norm_model(self, num_dim, device): 572*da0073e9SAndroid Build Coastguard Worker def layer_norm_model(num_classes, num_dim): 573*da0073e9SAndroid Build Coastguard Worker conv_layer = ( 574*da0073e9SAndroid Build Coastguard Worker nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d 575*da0073e9SAndroid Build Coastguard Worker ) 576*da0073e9SAndroid Build Coastguard Worker normalized_shape = [7] * num_dim 577*da0073e9SAndroid Build Coastguard Worker return nn.Sequential( 578*da0073e9SAndroid Build Coastguard Worker conv_layer(3, 32, kernel_size=3, stride=1, padding=1), 579*da0073e9SAndroid Build Coastguard Worker nn.LayerNorm(normalized_shape, elementwise_affine=True), 580*da0073e9SAndroid Build Coastguard Worker nn.Flatten(start_dim=1, end_dim=-1), 581*da0073e9SAndroid Build Coastguard Worker nn.Linear(32 * (7**num_dim), num_classes, bias=True), 582*da0073e9SAndroid Build Coastguard Worker ) 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Worker atol, rtol = self._compute_tolerances(device) 585*da0073e9SAndroid Build Coastguard Worker return self._test_conv_model( 586*da0073e9SAndroid Build Coastguard Worker layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol 587*da0073e9SAndroid Build Coastguard Worker ) 588*da0073e9SAndroid Build Coastguard Worker 589*da0073e9SAndroid Build Coastguard Worker def test_embedding_model(self, device): 590*da0073e9SAndroid Build Coastguard Worker def embedding_model(num_classes, num_embedding): 591*da0073e9SAndroid Build Coastguard Worker return nn.Sequential( 592*da0073e9SAndroid Build Coastguard Worker nn.Embedding(num_embedding, 15), 593*da0073e9SAndroid Build Coastguard Worker nn.Flatten(start_dim=1, end_dim=-1), 594*da0073e9SAndroid Build Coastguard Worker nn.Linear(375, num_classes, bias=True), 595*da0073e9SAndroid Build Coastguard Worker ) 596*da0073e9SAndroid Build Coastguard Worker 597*da0073e9SAndroid Build Coastguard Worker return self._test_embedding_model(embedding_model, 16, device) 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker def test_group_norm_error(self, device): 600*da0073e9SAndroid Build Coastguard Worker # group norm has to call native_group_norm. This checks that it hits the same errors 601*da0073e9SAndroid Build Coastguard Worker # that normal group norm would 602*da0073e9SAndroid Build Coastguard Worker 603*da0073e9SAndroid Build Coastguard Worker N = 3 604*da0073e9SAndroid Build Coastguard Worker C = 5 605*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(N, C) 606*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 607*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Expected number of channels in input to be divisible" 608*da0073e9SAndroid Build Coastguard Worker ): 609*da0073e9SAndroid Build Coastguard Worker F.group_norm(inp, 2) # 5 is not divisible by 2 610*da0073e9SAndroid Build Coastguard Worker 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Workerclass TestExpandedWeightModule(TestCase): 613*da0073e9SAndroid Build Coastguard Worker def _do_test( 614*da0073e9SAndroid Build Coastguard Worker self, 615*da0073e9SAndroid Build Coastguard Worker module, 616*da0073e9SAndroid Build Coastguard Worker input, 617*da0073e9SAndroid Build Coastguard Worker args=None, 618*da0073e9SAndroid Build Coastguard Worker kwargs=None, 619*da0073e9SAndroid Build Coastguard Worker batch_first=True, 620*da0073e9SAndroid Build Coastguard Worker atol=None, 621*da0073e9SAndroid Build Coastguard Worker rtol=None, 622*da0073e9SAndroid Build Coastguard Worker ): 623*da0073e9SAndroid Build Coastguard Worker args = args or () 624*da0073e9SAndroid Build Coastguard Worker kwargs = kwargs or {} 625*da0073e9SAndroid Build Coastguard Worker 626*da0073e9SAndroid Build Coastguard Worker batch_dim = 0 if batch_first else 1 627*da0073e9SAndroid Build Coastguard Worker batch_size = input.shape[batch_dim] 628*da0073e9SAndroid Build Coastguard Worker diff_input = input.dtype == torch.float or input.dtype == torch.double 629*da0073e9SAndroid Build Coastguard Worker if diff_input: 630*da0073e9SAndroid Build Coastguard Worker input.requires_grad_() 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 633*da0073e9SAndroid Build Coastguard Worker # get per sample grads with ExpandedWeights context manager 634*da0073e9SAndroid Build Coastguard Worker actual_res = call_for_per_sample_grads( 635*da0073e9SAndroid Build Coastguard Worker module, 636*da0073e9SAndroid Build Coastguard Worker batch_size=batch_size, 637*da0073e9SAndroid Build Coastguard Worker loss_reduction="sum", 638*da0073e9SAndroid Build Coastguard Worker batch_first=batch_first, 639*da0073e9SAndroid Build Coastguard Worker )(input, *args, **kwargs).sum() 640*da0073e9SAndroid Build Coastguard Worker actual_res.backward() 641*da0073e9SAndroid Build Coastguard Worker actual_grads = [] 642*da0073e9SAndroid Build Coastguard Worker for param in module.parameters(): 643*da0073e9SAndroid Build Coastguard Worker actual_grads.append(param.grad_sample) 644*da0073e9SAndroid Build Coastguard Worker del param.grad_sample 645*da0073e9SAndroid Build Coastguard Worker if diff_input: 646*da0073e9SAndroid Build Coastguard Worker actual_grads.append(input.grad.clone()) 647*da0073e9SAndroid Build Coastguard Worker input.grad = torch.zeros_like(input.grad) 648*da0073e9SAndroid Build Coastguard Worker 649*da0073e9SAndroid Build Coastguard Worker # get per sample grads with a for loop 650*da0073e9SAndroid Build Coastguard Worker expected_res = torch.tensor( 651*da0073e9SAndroid Build Coastguard Worker 0.0, device=input.device, dtype=actual_res.dtype 652*da0073e9SAndroid Build Coastguard Worker ) 653*da0073e9SAndroid Build Coastguard Worker expected_grads = [] 654*da0073e9SAndroid Build Coastguard Worker for i in range(batch_size): 655*da0073e9SAndroid Build Coastguard Worker input_slice = input.narrow(batch_dim, i, 1) 656*da0073e9SAndroid Build Coastguard Worker input_slice = input_slice.squeeze(batch_dim) 657*da0073e9SAndroid Build Coastguard Worker 658*da0073e9SAndroid Build Coastguard Worker # h's batch dim is always the first dim. Must be contiguous for CUDA 659*da0073e9SAndroid Build Coastguard Worker sliced_args = tree_map_only( 660*da0073e9SAndroid Build Coastguard Worker torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args 661*da0073e9SAndroid Build Coastguard Worker ) 662*da0073e9SAndroid Build Coastguard Worker diff_params = module.parameters() 663*da0073e9SAndroid Build Coastguard Worker if diff_input: 664*da0073e9SAndroid Build Coastguard Worker diff_params = chain(diff_params, (input_slice,)) 665*da0073e9SAndroid Build Coastguard Worker res = module( 666*da0073e9SAndroid Build Coastguard Worker input_slice.unsqueeze(batch_dim).contiguous(), 667*da0073e9SAndroid Build Coastguard Worker *sliced_args, 668*da0073e9SAndroid Build Coastguard Worker **kwargs, 669*da0073e9SAndroid Build Coastguard Worker ).sum() 670*da0073e9SAndroid Build Coastguard Worker out_grads = torch.autograd.grad( 671*da0073e9SAndroid Build Coastguard Worker res, diff_params, torch.ones_like(res), allow_unused=True 672*da0073e9SAndroid Build Coastguard Worker ) 673*da0073e9SAndroid Build Coastguard Worker expected_grads.append(out_grads) 674*da0073e9SAndroid Build Coastguard Worker expected_res += res 675*da0073e9SAndroid Build Coastguard Worker expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] 676*da0073e9SAndroid Build Coastguard Worker if not batch_first: 677*da0073e9SAndroid Build Coastguard Worker expected_grads[-1] = expected_grads[-1].transpose(0, 1) 678*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_res, expected_res) 679*da0073e9SAndroid Build Coastguard Worker [ 680*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, atol=atol, rtol=rtol) 681*da0073e9SAndroid Build Coastguard Worker for (actual, expected) in zip(actual_grads, expected_grads) 682*da0073e9SAndroid Build Coastguard Worker ] 683*da0073e9SAndroid Build Coastguard Worker 684*da0073e9SAndroid Build Coastguard Worker def _do_test_multi_input(self, module, input): 685*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 686*da0073e9SAndroid Build Coastguard Worker def __init__(self, module): 687*da0073e9SAndroid Build Coastguard Worker super().__init__() 688*da0073e9SAndroid Build Coastguard Worker self.module = module 689*da0073e9SAndroid Build Coastguard Worker 690*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 691*da0073e9SAndroid Build Coastguard Worker return self.module(input) + self.module(input) 692*da0073e9SAndroid Build Coastguard Worker 693*da0073e9SAndroid Build Coastguard Worker batch_size = input.shape[0] 694*da0073e9SAndroid Build Coastguard Worker diff_input = input.dtype == torch.float or input.dtype == torch.double 695*da0073e9SAndroid Build Coastguard Worker if diff_input: 696*da0073e9SAndroid Build Coastguard Worker input.requires_grad_() 697*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 698*da0073e9SAndroid Build Coastguard Worker # get per sample grads with ExpandedWeights context manager, calling .backward() twice 699*da0073e9SAndroid Build Coastguard Worker test_module = TestModule(module) 700*da0073e9SAndroid Build Coastguard Worker actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")( 701*da0073e9SAndroid Build Coastguard Worker input 702*da0073e9SAndroid Build Coastguard Worker ).sum() 703*da0073e9SAndroid Build Coastguard Worker actual_res.backward() 704*da0073e9SAndroid Build Coastguard Worker actual_grads = [] 705*da0073e9SAndroid Build Coastguard Worker for param in module.parameters(): 706*da0073e9SAndroid Build Coastguard Worker actual_grads.append(param.grad_sample) 707*da0073e9SAndroid Build Coastguard Worker del param.grad_sample 708*da0073e9SAndroid Build Coastguard Worker if diff_input: 709*da0073e9SAndroid Build Coastguard Worker actual_grads.append(input.grad.clone()) 710*da0073e9SAndroid Build Coastguard Worker input.grad = torch.zeros_like(input.grad) 711*da0073e9SAndroid Build Coastguard Worker 712*da0073e9SAndroid Build Coastguard Worker # get per sample grads with a for loop, running over the input twice 713*da0073e9SAndroid Build Coastguard Worker expected_grads = [] 714*da0073e9SAndroid Build Coastguard Worker for i in range(batch_size): 715*da0073e9SAndroid Build Coastguard Worker input_slice = input[i] 716*da0073e9SAndroid Build Coastguard Worker diff_params = module.parameters() 717*da0073e9SAndroid Build Coastguard Worker if diff_input: 718*da0073e9SAndroid Build Coastguard Worker diff_params = chain(diff_params, (input_slice,)) 719*da0073e9SAndroid Build Coastguard Worker res = module(input_slice.unsqueeze(0)).sum() 720*da0073e9SAndroid Build Coastguard Worker out_grads = torch.autograd.grad( 721*da0073e9SAndroid Build Coastguard Worker res, diff_params, torch.ones_like(res), allow_unused=True 722*da0073e9SAndroid Build Coastguard Worker ) 723*da0073e9SAndroid Build Coastguard Worker expected_grads.append(out_grads) 724*da0073e9SAndroid Build Coastguard Worker expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads)) 725*da0073e9SAndroid Build Coastguard Worker expected_grads = tuple( 726*da0073e9SAndroid Build Coastguard Worker expected_grad 727*da0073e9SAndroid Build Coastguard Worker for expected_grad in expected_grads 728*da0073e9SAndroid Build Coastguard Worker if expected_grad is not None 729*da0073e9SAndroid Build Coastguard Worker ) 730*da0073e9SAndroid Build Coastguard Worker assert [ 731*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, 2 * expected) 732*da0073e9SAndroid Build Coastguard Worker for (actual, expected) in zip(actual_grads, expected_grads) 733*da0073e9SAndroid Build Coastguard Worker ] 734*da0073e9SAndroid Build Coastguard Worker 735*da0073e9SAndroid Build Coastguard Worker def _do_test_rnn_packed_sequence( 736*da0073e9SAndroid Build Coastguard Worker self, module, input, args=None, kwargs=None, atol=None, rtol=None 737*da0073e9SAndroid Build Coastguard Worker ): 738*da0073e9SAndroid Build Coastguard Worker args = args if args is not None else () 739*da0073e9SAndroid Build Coastguard Worker kwargs = kwargs if kwargs is not None else {} 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard Worker batch_size = max(tuple(input.batch_sizes)).item() 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 744*da0073e9SAndroid Build Coastguard Worker # get per sample grads with ExpandedWeights context manager 745*da0073e9SAndroid Build Coastguard Worker actual_res = call_for_per_sample_grads( 746*da0073e9SAndroid Build Coastguard Worker module, batch_size=batch_size, loss_reduction="sum" 747*da0073e9SAndroid Build Coastguard Worker )(input, *args, **kwargs).data.sum() 748*da0073e9SAndroid Build Coastguard Worker actual_res.backward() 749*da0073e9SAndroid Build Coastguard Worker actual_grads = [] 750*da0073e9SAndroid Build Coastguard Worker for param in module.parameters(): 751*da0073e9SAndroid Build Coastguard Worker self.assertEqual(param.grad_sample.shape[0], batch_size) 752*da0073e9SAndroid Build Coastguard Worker actual_grads.append(param.grad_sample) 753*da0073e9SAndroid Build Coastguard Worker del param.grad_sample 754*da0073e9SAndroid Build Coastguard Worker 755*da0073e9SAndroid Build Coastguard Worker input.data.grad = torch.zeros_like(input.data) 756*da0073e9SAndroid Build Coastguard Worker 757*da0073e9SAndroid Build Coastguard Worker # compute the per sample grads with a for loop 758*da0073e9SAndroid Build Coastguard Worker expected_res = torch.zeros_like(actual_res) 759*da0073e9SAndroid Build Coastguard Worker expected_grads = [] 760*da0073e9SAndroid Build Coastguard Worker padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence( 761*da0073e9SAndroid Build Coastguard Worker input, batch_first=True 762*da0073e9SAndroid Build Coastguard Worker ) 763*da0073e9SAndroid Build Coastguard Worker for i in range(len(seq_sizes)): 764*da0073e9SAndroid Build Coastguard Worker input_slice = padded_input[i].narrow(0, 0, seq_sizes[i]) 765*da0073e9SAndroid Build Coastguard Worker diff_params = module.parameters() 766*da0073e9SAndroid Build Coastguard Worker batch_dim = 0 if module.m.batch_first else 1 767*da0073e9SAndroid Build Coastguard Worker res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum() 768*da0073e9SAndroid Build Coastguard Worker expected_res += res 769*da0073e9SAndroid Build Coastguard Worker out_grads = torch.autograd.grad( 770*da0073e9SAndroid Build Coastguard Worker res, diff_params, torch.ones_like(res), allow_unused=True 771*da0073e9SAndroid Build Coastguard Worker ) 772*da0073e9SAndroid Build Coastguard Worker expected_grads.append(out_grads) 773*da0073e9SAndroid Build Coastguard Worker 774*da0073e9SAndroid Build Coastguard Worker expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] 775*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_res, expected_res) 776*da0073e9SAndroid Build Coastguard Worker [ 777*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, atol=atol, rtol=rtol) 778*da0073e9SAndroid Build Coastguard Worker for (actual, expected) in zip(actual_grads, expected_grads) 779*da0073e9SAndroid Build Coastguard Worker ] 780*da0073e9SAndroid Build Coastguard Worker 781*da0073e9SAndroid Build Coastguard Worker @modules( 782*da0073e9SAndroid Build Coastguard Worker filter( 783*da0073e9SAndroid Build Coastguard Worker lambda m_info: m_info.module_cls 784*da0073e9SAndroid Build Coastguard Worker in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), 785*da0073e9SAndroid Build Coastguard Worker module_db, 786*da0073e9SAndroid Build Coastguard Worker ) 787*da0073e9SAndroid Build Coastguard Worker ) 788*da0073e9SAndroid Build Coastguard Worker @tf32_off() 789*da0073e9SAndroid Build Coastguard Worker def test_module(self, device, dtype, module_info, training): 790*da0073e9SAndroid Build Coastguard Worker class RNNWrapper(torch.nn.Module): 791*da0073e9SAndroid Build Coastguard Worker def __init__(self, m_cons, args, kwargs): 792*da0073e9SAndroid Build Coastguard Worker super().__init__() 793*da0073e9SAndroid Build Coastguard Worker self.m = m_cons(*args, **kwargs) 794*da0073e9SAndroid Build Coastguard Worker 795*da0073e9SAndroid Build Coastguard Worker def forward(self, *inps): 796*da0073e9SAndroid Build Coastguard Worker ret = self.m(*inps) 797*da0073e9SAndroid Build Coastguard Worker assert isinstance(ret, tuple) 798*da0073e9SAndroid Build Coastguard Worker return ret[0] 799*da0073e9SAndroid Build Coastguard Worker 800*da0073e9SAndroid Build Coastguard Worker def batch_hidden(h): 801*da0073e9SAndroid Build Coastguard Worker new_h_shape = [1] * (len(h.shape) + 1) 802*da0073e9SAndroid Build Coastguard Worker new_h_shape[1] = 2 803*da0073e9SAndroid Build Coastguard Worker return h.unsqueeze(1).repeat(new_h_shape) 804*da0073e9SAndroid Build Coastguard Worker 805*da0073e9SAndroid Build Coastguard Worker module_cls = module_info.module_cls 806*da0073e9SAndroid Build Coastguard Worker atol, rtol = ( 807*da0073e9SAndroid Build Coastguard Worker (1e-4, 1e-5) 808*da0073e9SAndroid Build Coastguard Worker if module_cls == torch.nn.GRU and dtype == torch.float32 809*da0073e9SAndroid Build Coastguard Worker else (None, None) 810*da0073e9SAndroid Build Coastguard Worker ) 811*da0073e9SAndroid Build Coastguard Worker module_inputs = module_info.module_inputs_func( 812*da0073e9SAndroid Build Coastguard Worker module_info, 813*da0073e9SAndroid Build Coastguard Worker device=device, 814*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 815*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 816*da0073e9SAndroid Build Coastguard Worker training=training, 817*da0073e9SAndroid Build Coastguard Worker with_packed_sequence=True, 818*da0073e9SAndroid Build Coastguard Worker ) 819*da0073e9SAndroid Build Coastguard Worker for module_input in module_inputs: 820*da0073e9SAndroid Build Coastguard Worker if module_input.forward_input is None: 821*da0073e9SAndroid Build Coastguard Worker continue 822*da0073e9SAndroid Build Coastguard Worker args, kwargs = ( 823*da0073e9SAndroid Build Coastguard Worker module_input.constructor_input.args, 824*da0073e9SAndroid Build Coastguard Worker module_input.constructor_input.kwargs, 825*da0073e9SAndroid Build Coastguard Worker ) 826*da0073e9SAndroid Build Coastguard Worker m = RNNWrapper(module_cls, args, kwargs) 827*da0073e9SAndroid Build Coastguard Worker batch_first = m.m.batch_first 828*da0073e9SAndroid Build Coastguard Worker m.to(device).to(dtype) 829*da0073e9SAndroid Build Coastguard Worker 830*da0073e9SAndroid Build Coastguard Worker args, kwargs = ( 831*da0073e9SAndroid Build Coastguard Worker module_input.forward_input.args, 832*da0073e9SAndroid Build Coastguard Worker module_input.forward_input.kwargs, 833*da0073e9SAndroid Build Coastguard Worker ) 834*da0073e9SAndroid Build Coastguard Worker 835*da0073e9SAndroid Build Coastguard Worker # if the RNN tests use unbatched inputs--batch the inputs 836*da0073e9SAndroid Build Coastguard Worker input = args[0] 837*da0073e9SAndroid Build Coastguard Worker if isinstance(input, torch.Tensor) and input.dim() == 2: 838*da0073e9SAndroid Build Coastguard Worker input = input.detach() 839*da0073e9SAndroid Build Coastguard Worker new_input_shape = [1] * (len(input.shape) + 1) 840*da0073e9SAndroid Build Coastguard Worker if batch_first: 841*da0073e9SAndroid Build Coastguard Worker new_input_shape[0] = 2 842*da0073e9SAndroid Build Coastguard Worker input = input.repeat(new_input_shape) 843*da0073e9SAndroid Build Coastguard Worker else: 844*da0073e9SAndroid Build Coastguard Worker new_input_shape[1] = 2 845*da0073e9SAndroid Build Coastguard Worker input = input.unsqueeze(1).repeat(new_input_shape) 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker h = args[1] if len(args) > 1 else None 848*da0073e9SAndroid Build Coastguard Worker if h is not None: 849*da0073e9SAndroid Build Coastguard Worker h = ( 850*da0073e9SAndroid Build Coastguard Worker batch_hidden(h) 851*da0073e9SAndroid Build Coastguard Worker if isinstance(h, torch.Tensor) 852*da0073e9SAndroid Build Coastguard Worker else tuple(batch_hidden(hx) for hx in h) 853*da0073e9SAndroid Build Coastguard Worker ) 854*da0073e9SAndroid Build Coastguard Worker args = list(args) 855*da0073e9SAndroid Build Coastguard Worker args[1] = h 856*da0073e9SAndroid Build Coastguard Worker 857*da0073e9SAndroid Build Coastguard Worker if isinstance(input, torch.nn.utils.rnn.PackedSequence): 858*da0073e9SAndroid Build Coastguard Worker self._do_test_rnn_packed_sequence( 859*da0073e9SAndroid Build Coastguard Worker m, input, args[1:], kwargs, atol=atol, rtol=rtol 860*da0073e9SAndroid Build Coastguard Worker ) 861*da0073e9SAndroid Build Coastguard Worker else: 862*da0073e9SAndroid Build Coastguard Worker self._do_test( 863*da0073e9SAndroid Build Coastguard Worker m, 864*da0073e9SAndroid Build Coastguard Worker input, 865*da0073e9SAndroid Build Coastguard Worker args[1:], 866*da0073e9SAndroid Build Coastguard Worker kwargs, 867*da0073e9SAndroid Build Coastguard Worker batch_first=batch_first, 868*da0073e9SAndroid Build Coastguard Worker atol=atol, 869*da0073e9SAndroid Build Coastguard Worker rtol=rtol, 870*da0073e9SAndroid Build Coastguard Worker ) 871*da0073e9SAndroid Build Coastguard Worker 872*da0073e9SAndroid Build Coastguard Worker def test_per_sample_api_failing(self): 873*da0073e9SAndroid Build Coastguard Worker module = nn.Linear(10, 10) 874*da0073e9SAndroid Build Coastguard Worker input = torch.randn(64, 10) 875*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"): 876*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads("fail")(input) 877*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 878*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Batch size passed must be None or an integer" 879*da0073e9SAndroid Build Coastguard Worker ): 880*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(module, batch_size=6.4)(input) 881*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"): 882*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(module, batch_size=-64)(input) 883*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"): 884*da0073e9SAndroid Build Coastguard Worker loss = call_for_per_sample_grads(module)(input).sum() 885*da0073e9SAndroid Build Coastguard Worker loss.backward() # populate grad_sample fields 886*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(module)(input) 887*da0073e9SAndroid Build Coastguard Worker 888*da0073e9SAndroid Build Coastguard Worker module = nn.Linear(10, 10) # reset to not have grad_sample fields 889*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 890*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Expected loss_reduction argument to be sum or mean" 891*da0073e9SAndroid Build Coastguard Worker ): 892*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(module, loss_reduction="")(input) 893*da0073e9SAndroid Build Coastguard Worker 894*da0073e9SAndroid Build Coastguard Worker def test_per_sample_api_compute_batch_size(self): 895*da0073e9SAndroid Build Coastguard Worker class CustomModule(nn.Module): 896*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 897*da0073e9SAndroid Build Coastguard Worker super().__init__() 898*da0073e9SAndroid Build Coastguard Worker self.linear = nn.Linear(5, 5) 899*da0073e9SAndroid Build Coastguard Worker 900*da0073e9SAndroid Build Coastguard Worker def forward(self, input1, input2): 901*da0073e9SAndroid Build Coastguard Worker return self.linear(input1) + self.linear(input2) 902*da0073e9SAndroid Build Coastguard Worker 903*da0073e9SAndroid Build Coastguard Worker module = CustomModule() 904*da0073e9SAndroid Build Coastguard Worker input1 = torch.randn(4, 5) 905*da0073e9SAndroid Build Coastguard Worker input2 = torch.randn(5, 5) 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 908*da0073e9SAndroid Build Coastguard Worker RuntimeError, 909*da0073e9SAndroid Build Coastguard Worker "found at least one input with batch size 4 and one with batch size 5", 910*da0073e9SAndroid Build Coastguard Worker ): 911*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(module)(input1, input2) 912*da0073e9SAndroid Build Coastguard Worker 913*da0073e9SAndroid Build Coastguard Worker input2 = torch.randn(4, 5) 914*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(module)(input1, input2) 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Worker module = CustomModule() 917*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(module)(input1, input2=input2) 918*da0073e9SAndroid Build Coastguard Worker 919*da0073e9SAndroid Build Coastguard Worker module = CustomModule() 920*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(module)(input1=input1, input2=input2) 921*da0073e9SAndroid Build Coastguard Worker 922*da0073e9SAndroid Build Coastguard Worker def test_per_sample_api_compute_batch_size_not_pytreeable(self): 923*da0073e9SAndroid Build Coastguard Worker @dataclass 924*da0073e9SAndroid Build Coastguard Worker class NonPytreeableTuple: 925*da0073e9SAndroid Build Coastguard Worker elem1: torch.Tensor 926*da0073e9SAndroid Build Coastguard Worker elem2: torch.Tensor 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker class CustomModule(nn.Module): 929*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 930*da0073e9SAndroid Build Coastguard Worker super().__init__() 931*da0073e9SAndroid Build Coastguard Worker self.linear = nn.Linear(5, 5) 932*da0073e9SAndroid Build Coastguard Worker 933*da0073e9SAndroid Build Coastguard Worker def forward(self, input1, input2): 934*da0073e9SAndroid Build Coastguard Worker return self.linear(input1.elem1) + self.linear(input1.elem2) 935*da0073e9SAndroid Build Coastguard Worker 936*da0073e9SAndroid Build Coastguard Worker input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5)) 937*da0073e9SAndroid Build Coastguard Worker model = CustomModule() 938*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 939*da0073e9SAndroid Build Coastguard Worker RuntimeError, 940*da0073e9SAndroid Build Coastguard Worker "ExpandedWeights cannot compute the batch size from the inputs", 941*da0073e9SAndroid Build Coastguard Worker ): 942*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(model)(input, "") 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker # would prefer for it to error because input is not pytree-able but that's hard to detect 945*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 946*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected ExpandedWeights to have batch size matching input" 947*da0073e9SAndroid Build Coastguard Worker ): 948*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(model)(input, torch.randn(5)) 949*da0073e9SAndroid Build Coastguard Worker 950*da0073e9SAndroid Build Coastguard Worker model = CustomModule() # TODO: functional call bug, sam will fix 951*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(model)(input, torch.randn(4, 5)) 952*da0073e9SAndroid Build Coastguard Worker model = CustomModule() 953*da0073e9SAndroid Build Coastguard Worker call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5)) 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker 956*da0073e9SAndroid Build Coastguard Workerclass ContextManagerTests(TestBase): 957*da0073e9SAndroid Build Coastguard Worker def __init__(self, *args, **kwargs): 958*da0073e9SAndroid Build Coastguard Worker self.test_cpu = kwargs.get("test_cpu", True) 959*da0073e9SAndroid Build Coastguard Worker self.test_cuda = kwargs.get("test_cuda", True) 960*da0073e9SAndroid Build Coastguard Worker super().__init__(*args, **kwargs) 961*da0073e9SAndroid Build Coastguard Worker 962*da0073e9SAndroid Build Coastguard Worker @property 963*da0073e9SAndroid Build Coastguard Worker def constructor_args(self): 964*da0073e9SAndroid Build Coastguard Worker return self._get_arg("constructor_args", False) 965*da0073e9SAndroid Build Coastguard Worker 966*da0073e9SAndroid Build Coastguard Worker def test_context_manager(self, test_case, device): 967*da0073e9SAndroid Build Coastguard Worker kwargs = {"device": device, "dtype": torch.double} 968*da0073e9SAndroid Build Coastguard Worker module = self.constructor(*self.constructor_args).to(**kwargs) 969*da0073e9SAndroid Build Coastguard Worker if "Embedding" in self.get_name(): 970*da0073e9SAndroid Build Coastguard Worker kwargs["dtype"] = torch.long 971*da0073e9SAndroid Build Coastguard Worker input = self._get_input().to(**kwargs) 972*da0073e9SAndroid Build Coastguard Worker if len(input.shape) == 0 or input.shape[0] == 0: 973*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest( 974*da0073e9SAndroid Build Coastguard Worker "Can't get per sample gradients when no batch dim or batch dim is 0" 975*da0073e9SAndroid Build Coastguard Worker ) 976*da0073e9SAndroid Build Coastguard Worker if self.constructor == torch.nn.Linear and len(input.shape) == 1: 977*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest( 978*da0073e9SAndroid Build Coastguard Worker "Can't get per sample gradients for input of rank 1" 979*da0073e9SAndroid Build Coastguard Worker ) 980*da0073e9SAndroid Build Coastguard Worker test_case._do_test(module, input) 981*da0073e9SAndroid Build Coastguard Worker 982*da0073e9SAndroid Build Coastguard Worker def test_context_manager_multiple_inputs(self, test_case, device): 983*da0073e9SAndroid Build Coastguard Worker module = self.constructor(*self.constructor_args).to(device) 984*da0073e9SAndroid Build Coastguard Worker input = self._get_input() 985*da0073e9SAndroid Build Coastguard Worker if len(input.shape) == 0 or input.shape[0] == 0: 986*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest( 987*da0073e9SAndroid Build Coastguard Worker "Can't get per sample gradients when no batch dim or batch dim is 0" 988*da0073e9SAndroid Build Coastguard Worker ) 989*da0073e9SAndroid Build Coastguard Worker if self.constructor == torch.nn.Linear and len(input.shape) == 1: 990*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest( 991*da0073e9SAndroid Build Coastguard Worker "Can't get per sample gradients for input of rank 1" 992*da0073e9SAndroid Build Coastguard Worker ) 993*da0073e9SAndroid Build Coastguard Worker test_case._do_test_multi_input(module, input) 994*da0073e9SAndroid Build Coastguard Worker 995*da0073e9SAndroid Build Coastguard Worker 996*da0073e9SAndroid Build Coastguard Workerdef filter_supported_tests(t): 997*da0073e9SAndroid Build Coastguard Worker supported_modules = [ 998*da0073e9SAndroid Build Coastguard Worker "Linear", 999*da0073e9SAndroid Build Coastguard Worker "Conv1d", 1000*da0073e9SAndroid Build Coastguard Worker "Conv2d", 1001*da0073e9SAndroid Build Coastguard Worker "Conv3d", 1002*da0073e9SAndroid Build Coastguard Worker "Embedding", 1003*da0073e9SAndroid Build Coastguard Worker "LayerNorm", 1004*da0073e9SAndroid Build Coastguard Worker "GroupNorm", 1005*da0073e9SAndroid Build Coastguard Worker "InstanceNorm", 1006*da0073e9SAndroid Build Coastguard Worker ] 1007*da0073e9SAndroid Build Coastguard Worker if "module_name" in t and t["module_name"] in supported_modules: 1008*da0073e9SAndroid Build Coastguard Worker return True 1009*da0073e9SAndroid Build Coastguard Worker 1010*da0073e9SAndroid Build Coastguard Worker 1011*da0073e9SAndroid Build Coastguard Worker# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests 1012*da0073e9SAndroid Build Coastguard Worker# These currently use the legacy nn tests 1013*da0073e9SAndroid Build Coastguard Workersupported_tests = [ 1014*da0073e9SAndroid Build Coastguard Worker t for t in module_tests + new_module_tests if filter_supported_tests(t) 1015*da0073e9SAndroid Build Coastguard Worker] 1016*da0073e9SAndroid Build Coastguard Workerfor test_param in supported_tests: 1017*da0073e9SAndroid Build Coastguard Worker if "constructor" not in test_param: 1018*da0073e9SAndroid Build Coastguard Worker name = test_param.pop("module_name") 1019*da0073e9SAndroid Build Coastguard Worker test_param["constructor"] = getattr(nn, name) 1020*da0073e9SAndroid Build Coastguard Worker decorator = test_param.pop("decorator", lambda test: test) 1021*da0073e9SAndroid Build Coastguard Worker test = ContextManagerTests(**test_param) 1022*da0073e9SAndroid Build Coastguard Worker test_name = test.get_name() 1023*da0073e9SAndroid Build Coastguard Worker if hasattr(TestExpandedWeightModule, test_name): 1024*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Found two tests with the same name: " + test_name) 1025*da0073e9SAndroid Build Coastguard Worker test_name_multi_input = test.get_name() + "_multiple_inputs" 1026*da0073e9SAndroid Build Coastguard Worker if hasattr(TestExpandedWeightModule, test_name_multi_input): 1027*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Found two tests with the same name: " + test_name) 1028*da0073e9SAndroid Build Coastguard Worker if test.test_cpu: 1029*da0073e9SAndroid Build Coastguard Worker setattr( 1030*da0073e9SAndroid Build Coastguard Worker TestExpandedWeightModule, 1031*da0073e9SAndroid Build Coastguard Worker test_name, 1032*da0073e9SAndroid Build Coastguard Worker decorator(lambda self, test=test: test.test_context_manager(self, "cpu")), 1033*da0073e9SAndroid Build Coastguard Worker ) 1034*da0073e9SAndroid Build Coastguard Worker setattr( 1035*da0073e9SAndroid Build Coastguard Worker TestExpandedWeightModule, 1036*da0073e9SAndroid Build Coastguard Worker test_name_multi_input, 1037*da0073e9SAndroid Build Coastguard Worker decorator( 1038*da0073e9SAndroid Build Coastguard Worker lambda self, test=test: test.test_context_manager_multiple_inputs( 1039*da0073e9SAndroid Build Coastguard Worker self, "cpu" 1040*da0073e9SAndroid Build Coastguard Worker ) 1041*da0073e9SAndroid Build Coastguard Worker ), 1042*da0073e9SAndroid Build Coastguard Worker ) 1043*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA and test.test_cuda: 1044*da0073e9SAndroid Build Coastguard Worker # since this checks derivatives, only use double for precision 1045*da0073e9SAndroid Build Coastguard Worker setattr( 1046*da0073e9SAndroid Build Coastguard Worker TestExpandedWeightModule, 1047*da0073e9SAndroid Build Coastguard Worker test_name + "_cuda_double", 1048*da0073e9SAndroid Build Coastguard Worker decorator(lambda self, test=test: test.test_context_manager(self, "cuda")), 1049*da0073e9SAndroid Build Coastguard Worker ) 1050*da0073e9SAndroid Build Coastguard Worker 1051*da0073e9SAndroid Build Coastguard Worker# ------------- HELPER FUNCTIONS ----------------- 1052*da0073e9SAndroid Build Coastguard Worker 1053*da0073e9SAndroid Build Coastguard Worker 1054*da0073e9SAndroid Build Coastguard Workerdef run_op(op, input, *args, **kwargs): 1055*da0073e9SAndroid Build Coastguard Worker r""" 1056*da0073e9SAndroid Build Coastguard Worker OpInfo for Embedding switches the input and weight so autograd tests will only check the derivative 1057*da0073e9SAndroid Build Coastguard Worker of the weight, not the input, which can't be differentiable since its dtype is int. Calls op, 1058*da0073e9SAndroid Build Coastguard Worker using the special ordering that Embedding's OpInfo expects for that case. 1059*da0073e9SAndroid Build Coastguard Worker """ 1060*da0073e9SAndroid Build Coastguard Worker if op.name == "nn.functional.embedding": 1061*da0073e9SAndroid Build Coastguard Worker return op(args[0], input, **kwargs) 1062*da0073e9SAndroid Build Coastguard Worker else: 1063*da0073e9SAndroid Build Coastguard Worker return op(input, *args, **kwargs) 1064*da0073e9SAndroid Build Coastguard Worker 1065*da0073e9SAndroid Build Coastguard Worker 1066*da0073e9SAndroid Build Coastguard Workerdef make_expanded_weight(sample_input, batch_size, loss_reduction="sum"): 1067*da0073e9SAndroid Build Coastguard Worker def expanded_weight_or_clone(arg): 1068*da0073e9SAndroid Build Coastguard Worker if is_diff_tensor(arg): 1069*da0073e9SAndroid Build Coastguard Worker return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction) 1070*da0073e9SAndroid Build Coastguard Worker return clone_if_tensor(arg) 1071*da0073e9SAndroid Build Coastguard Worker 1072*da0073e9SAndroid Build Coastguard Worker ew_input = clone_if_tensor(sample_input.input) 1073*da0073e9SAndroid Build Coastguard Worker ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args) 1074*da0073e9SAndroid Build Coastguard Worker ew_kwargs = { 1075*da0073e9SAndroid Build Coastguard Worker name: expanded_weight_or_clone(arg) 1076*da0073e9SAndroid Build Coastguard Worker for (name, arg) in sample_input.kwargs.items() 1077*da0073e9SAndroid Build Coastguard Worker } 1078*da0073e9SAndroid Build Coastguard Worker return ew_input, ew_args, ew_kwargs 1079*da0073e9SAndroid Build Coastguard Worker 1080*da0073e9SAndroid Build Coastguard Worker 1081*da0073e9SAndroid Build Coastguard Workerdef supported_inputs(op, sample_inputs, supported_inputs=True): 1082*da0073e9SAndroid Build Coastguard Worker r""" 1083*da0073e9SAndroid Build Coastguard Worker ExpandedWeights currently does not support some use cases when there's no batch dimension or 1084*da0073e9SAndroid Build Coastguard Worker operations that would cause inter-batch operations. Removes all of the cases it cannot deal with 1085*da0073e9SAndroid Build Coastguard Worker """ 1086*da0073e9SAndroid Build Coastguard Worker 1087*da0073e9SAndroid Build Coastguard Worker def filter_fn(input): 1088*da0073e9SAndroid Build Coastguard Worker convolutions = [ 1089*da0073e9SAndroid Build Coastguard Worker "nn.functional.conv1d", 1090*da0073e9SAndroid Build Coastguard Worker "nn.functional.conv2d", 1091*da0073e9SAndroid Build Coastguard Worker "nn.functional.conv3d", 1092*da0073e9SAndroid Build Coastguard Worker ] 1093*da0073e9SAndroid Build Coastguard Worker batched_input_size = dict(zip(convolutions, [3, 4, 5])) 1094*da0073e9SAndroid Build Coastguard Worker if op.name == "nn.functional.linear": 1095*da0073e9SAndroid Build Coastguard Worker is_supported_input = ( 1096*da0073e9SAndroid Build Coastguard Worker input.input.dim() > 1 1097*da0073e9SAndroid Build Coastguard Worker ) # input of rank 1 means no batch dim 1098*da0073e9SAndroid Build Coastguard Worker elif op.name == "nn.functional.layer_norm": 1099*da0073e9SAndroid Build Coastguard Worker normalized_shape = input.args[0] 1100*da0073e9SAndroid Build Coastguard Worker is_supported_input = ( 1101*da0073e9SAndroid Build Coastguard Worker input.input.shape != normalized_shape 1102*da0073e9SAndroid Build Coastguard Worker ) # would cause inter-batch operations 1103*da0073e9SAndroid Build Coastguard Worker elif op.name in convolutions: 1104*da0073e9SAndroid Build Coastguard Worker # currently can't deal with padding computation on Python level 1105*da0073e9SAndroid Build Coastguard Worker is_supported_input = input.input.dim() == batched_input_size[op.name] 1106*da0073e9SAndroid Build Coastguard Worker elif op.name == "nn.functional.embedding": 1107*da0073e9SAndroid Build Coastguard Worker idx = input.args[0] 1108*da0073e9SAndroid Build Coastguard Worker is_supported_input = len(idx.shape) > 1 # there's no batch size 1109*da0073e9SAndroid Build Coastguard Worker else: 1110*da0073e9SAndroid Build Coastguard Worker is_supported_input = True 1111*da0073e9SAndroid Build Coastguard Worker is_supported_input = ( 1112*da0073e9SAndroid Build Coastguard Worker is_supported_input and input.input.shape[0] > 0 1113*da0073e9SAndroid Build Coastguard Worker ) # 0 is not a valid batch size 1114*da0073e9SAndroid Build Coastguard Worker return is_supported_input if supported_inputs else not is_supported_input 1115*da0073e9SAndroid Build Coastguard Worker 1116*da0073e9SAndroid Build Coastguard Worker return [input for input in sample_inputs if filter_fn(input)] 1117*da0073e9SAndroid Build Coastguard Worker 1118*da0073e9SAndroid Build Coastguard Worker 1119*da0073e9SAndroid Build Coastguard Workerdef for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs): 1120*da0073e9SAndroid Build Coastguard Worker # get per sample grads by getting derivative for each input in a for loop 1121*da0073e9SAndroid Build Coastguard Worker per_sample_grad = [] 1122*da0073e9SAndroid Build Coastguard Worker for i in range(batch_size): 1123*da0073e9SAndroid Build Coastguard Worker per_sample_input = input[i] 1124*da0073e9SAndroid Build Coastguard Worker result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs)) 1125*da0073e9SAndroid Build Coastguard Worker diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values()) 1126*da0073e9SAndroid Build Coastguard Worker diff_input_list = [ 1127*da0073e9SAndroid Build Coastguard Worker i 1128*da0073e9SAndroid Build Coastguard Worker for i in diff_input_list 1129*da0073e9SAndroid Build Coastguard Worker if isinstance(i, torch.Tensor) and i.requires_grad 1130*da0073e9SAndroid Build Coastguard Worker ] 1131*da0073e9SAndroid Build Coastguard Worker per_sample_grad.append( 1132*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad( 1133*da0073e9SAndroid Build Coastguard Worker result, diff_input_list, torch.ones_like(result), allow_unused=True 1134*da0073e9SAndroid Build Coastguard Worker ) 1135*da0073e9SAndroid Build Coastguard Worker ) 1136*da0073e9SAndroid Build Coastguard Worker if len(per_sample_grad) == batch_size: 1137*da0073e9SAndroid Build Coastguard Worker per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad)) 1138*da0073e9SAndroid Build Coastguard Worker return per_sample_grad 1139*da0073e9SAndroid Build Coastguard Worker 1140*da0073e9SAndroid Build Coastguard Worker 1141*da0073e9SAndroid Build Coastguard Workerdef is_diff_tensor(t): 1142*da0073e9SAndroid Build Coastguard Worker return isinstance(t, ExpandedWeight) or ( 1143*da0073e9SAndroid Build Coastguard Worker isinstance(t, torch.Tensor) and t.requires_grad 1144*da0073e9SAndroid Build Coastguard Worker ) 1145*da0073e9SAndroid Build Coastguard Worker 1146*da0073e9SAndroid Build Coastguard Worker 1147*da0073e9SAndroid Build Coastguard Workerdef clone_if_tensor(t): 1148*da0073e9SAndroid Build Coastguard Worker if isinstance(t, torch.Tensor): 1149*da0073e9SAndroid Build Coastguard Worker res = torch.clone(t).detach() 1150*da0073e9SAndroid Build Coastguard Worker res.requires_grad_(t.requires_grad) 1151*da0073e9SAndroid Build Coastguard Worker return res 1152*da0073e9SAndroid Build Coastguard Worker else: 1153*da0073e9SAndroid Build Coastguard Worker return t 1154*da0073e9SAndroid Build Coastguard Worker 1155*da0073e9SAndroid Build Coastguard Worker 1156*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestExpandedWeightHelperFunction, globals()) 1157*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestExpandedWeightFunctional, globals()) 1158*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestExpandedWeightModule, globals()) 1159*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1160*da0073e9SAndroid Build Coastguard Worker run_tests() 1161