# Owner(s): ["module: nn"]
import math
import random
import string
import unittest
from functools import reduce
from operator import mul

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfNoLapack,
    skipIfTorchDynamo,
    slowTest,
    TEST_SCIPY,
    TestCase,
)


if TEST_SCIPY:
    from scipy import stats


class TestNNInit(TestCase):
    def setUp(self):
        super().setUp()
        random.seed(123)
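
    # The _is_* helpers below decide distributional claims with scipy's
    # Kolmogorov-Smirnov test: stats.kstest returns (statistic, p_value), and a
    # p_value above the loose 0.0001 threshold means we fail to reject the
    # hypothesis that the samples come from the stated distribution.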

    def _is_normal(self, tensor, mean, std):
        samples = tensor.view(-1).tolist()
        p_value = stats.kstest(samples, "norm", args=(mean, std))[1]
        return p_value > 0.0001

    def _is_trunc_normal(self, tensor, mean, std, a, b):
        # scipy's truncnorm is suited for data drawn from N(0, 1),
        # so we need to transform our data to test it using scipy.
        z_samples = (tensor.view(-1) - mean) / std
        z_samples = z_samples.tolist()
        a0 = (a - mean) / std
        b0 = (b - mean) / std
        p_value = stats.kstest(z_samples, "truncnorm", args=(a0, b0))[1]
        return p_value > 0.0001

    def _is_uniform(self, tensor, a, b):
        samples = tensor.view(-1).tolist()
        # scipy's "uniform" takes loc and scale, i.e. the interval [a, a + (b - a)]
        p_value = stats.kstest(samples, "uniform", args=(a, (b - a)))[1]
        return p_value > 0.0001

    def _create_random_nd_tensor(self, dims, size_min, size_max):
        size = [random.randint(size_min, size_max) for _ in range(dims)]
        tensor = torch.zeros(size)
        return tensor

    def _random_float(self, a, b):
        return (b - a) * random.random() + a
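
    # Illustrative sketch, not one of the original tests: how the statistical
    # helpers above combine. The leading underscore keeps unittest from
    # collecting it, and the 1e-4 KS threshold makes a false failure possible
    # but rare.
    def _example_ks_helper_usage(self):
        samples = torch.randn(10_000)  # draws from N(0, 1)
        assert self._is_normal(samples, 0, 1)
        assert not self._is_uniform(samples, -1, 1)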

    def test_calculate_gain_linear(self):
        for fn in [
            "linear",
            "conv1d",
            "conv2d",
            "conv3d",
            "conv_transpose1d",
            "conv_transpose2d",
            "conv_transpose3d",
        ]:
            gain = init.calculate_gain(fn)
            self.assertEqual(gain, 1)

    def test_calculate_gain_nonlinear(self):
        for fn in ["sigmoid", "tanh", "relu", "leaky_relu", "selu"]:
            gain = init.calculate_gain(fn)
            if fn == "sigmoid":
                self.assertEqual(gain, 1)
            elif fn == "tanh":  # 5 / 3
                self.assertEqual(gain, 1.6666666666666667)
            elif fn == "relu":  # sqrt(2)
                self.assertEqual(gain, 1.4142135623730951)
            elif fn == "leaky_relu":  # sqrt(2 / (1 + slope^2))
                self.assertEqual(gain, 1.4141428569978354)
            elif fn == "selu":  # 3 / 4
                self.assertEqual(gain, 0.75)

    def test_calculate_gain_leaky_relu(self):
        for param in [None, 0, 0.01, 10]:
            gain = init.calculate_gain("leaky_relu", param)
            if param is None:  # Default slope is 0.01
                self.assertEqual(gain, 1.4141428569978354)
            elif param == 0:  # No slope = same gain as normal ReLU
                self.assertEqual(gain, 1.4142135623730951)
            elif param == 0.01:
                self.assertEqual(gain, 1.4141428569978354)
            elif param == 10:
                self.assertEqual(gain, 0.14071950894605836)
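
    # Worked arithmetic for the leaky_relu gains checked above, using
    # gain = sqrt(2 / (1 + negative_slope**2)):
    #   slope 0.01 (default): sqrt(2 / 1.0001) = 1.4141428569978354
    #   slope 0:              sqrt(2)          = 1.4142135623730951 (plain ReLU)
    #   slope 10:             sqrt(2 / 101)    = 0.14071950894605836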

    def test_calculate_gain_leaky_relu_only_accepts_numbers(self):
        for param in [True, [1], {"a": "b"}]:
            with self.assertRaises(ValueError):
                init.calculate_gain("leaky_relu", param)

    def test_calculate_gain_only_accepts_valid_nonlinearities(self):
        for n in [2, 5, 25]:
            # Generate random strings of lengths that definitely aren't supported
            random_string = "".join(
                [random.choice(string.ascii_lowercase) for _ in range(n)]
            )
            with self.assertRaises(ValueError):
                init.calculate_gain(random_string)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_uniform(self):
        for dims in [1, 2, 4]:
            input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50)
            a = self._random_float(-3, 3)
            b = a + self._random_float(1, 5)
            init.uniform_(input_tensor, a=a, b=b)
            assert self._is_uniform(input_tensor, a, b)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_normal(self):
        for dims in [1, 2, 4]:
            input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50)
            mean = self._random_float(-3, 3)
            std = self._random_float(1, 5)
            init.normal_(input_tensor, mean=mean, std=std)

            assert self._is_normal(input_tensor, mean, std)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_trunc_normal(self):
        for dims in [1, 2, 4]:
            input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50)
            mean = self._random_float(-3, 3)
            std = self._random_float(0.01, 1)
            a = self._random_float(mean - 2 * std, mean)
            b = self._random_float(mean, mean + 2 * std)
            init.trunc_normal_(input_tensor, mean=mean, std=std, a=a, b=b)

            assert self._is_trunc_normal(input_tensor, mean, std, a, b)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_trunc_normal_generator(self):
        gen = torch.Generator()
        gen.manual_seed(42)
        input_tensor = torch.empty(5)
        init.trunc_normal_(input_tensor, generator=gen)

        ref = torch.empty(5)
        torch.manual_seed(42)
        init.trunc_normal_(ref)

        self.assertEqual(input_tensor, ref)
        assert self._is_trunc_normal(input_tensor, mean=0, std=1, a=0, b=1)
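
    # The reproducibility check above relies on a fresh torch.Generator() seeded
    # with 42 producing the same CPU stream as the default generator seeded via
    # torch.manual_seed(42).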

    def test_constant(self):
        for dims in [1, 2, 4]:
            input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5)
            val = self._random_float(1, 10)
            init.constant_(input_tensor, val)

            self.assertEqual(input_tensor, input_tensor.clone().fill_(val))

    def test_ones_and_zeros(self):
        for init_fn_, val in zip([init.ones_, init.zeros_], [1, 0]):
            for dims in [1, 2, 4]:
                input_tensor = self._create_random_nd_tensor(
                    dims, size_min=1, size_max=5
                )
                init_fn_(input_tensor)

                self.assertEqual(input_tensor, input_tensor.clone().fill_(val))

    def test_eye(self):
        input_tensor = self._create_random_nd_tensor(2, size_min=1, size_max=5)
        init.eye_(input_tensor)

        # Check every single element
        for i in range(input_tensor.size(0)):
            for j in range(input_tensor.size(1)):
                if i == j:
                    assert input_tensor[i][j] == 1
                else:
                    assert input_tensor[i][j] == 0

    def test_eye_only_works_on_2d_inputs(self):
        for dims in [1, 3]:
            with self.assertRaises(ValueError):
                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3)
                init.eye_(tensor)

    def test_dirac_properties(self):
        for dims in [3, 4, 5]:
            for groups in [1, 2, 3]:
                # prepare random tensor with random sizes that fit the groups
                a, c, d, e = (random.randint(1, 5) for _ in range(4))
                b = random.randint(
                    1, 5 * groups
                )  # same upper bound as a * groups, but any value in range allowed
                # make sure the first dim is divisible by groups
                input_tensor = torch.randn((a * groups, b, c, d, e)[:dims])

                init.dirac_(input_tensor, groups)

                c_out, c_in = input_tensor.size(0) // groups, input_tensor.size(1)
                min_d = min(c_out, c_in)
                # Check number of nonzeros is equal to the smallest dim (for each group)
                assert torch.nonzero(input_tensor).size(0) == min_d * groups
                # Check sum of values (can have precision issues, hence assertEqual) is also equal
                self.assertEqual(input_tensor.sum(), min_d * groups)
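
    # Why the dirac_ counts above hold: per group, each of the min(c_out, c_in)
    # preserved channels gets exactly one weight set to 1 at the spatial center,
    # so the nonzero count and the total sum both equal min_d * groups.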

    def test_dirac_identity(self):
        for groups in [1, 3]:
            batch, in_c, out_c, size, kernel_size = (
                8,
                3,
                9,
                5,
                3,
            )  # in_c, out_c must be divisible by groups
            eff_out_c = out_c // groups

            # Test 1D
            input_var = torch.randn(batch, in_c, size)
            filter_var = torch.zeros(eff_out_c, in_c, kernel_size)
            filter_var = torch.cat([filter_var] * groups)
            init.dirac_(filter_var, groups)
            output_var = F.conv1d(input_var, filter_var)
            input_tensor, output_tensor = (
                input_var.data,
                output_var.data,
            )  # Variables do not support nonzero
            for g in range(groups):
                # Assert in_c outputs are preserved (per group)
                self.assertEqual(
                    input_tensor[:, :, 1:-1],
                    output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :],
                )
                # Assert extra outputs are 0
                assert (
                    torch.nonzero(
                        output_tensor[:, eff_out_c * g + in_c : eff_out_c * (g + 1), :]
                    ).numel()
                    == 0
                )

            # Test 2D
            input_var = torch.randn(batch, in_c, size, size)
            filter_var = torch.zeros(eff_out_c, in_c, kernel_size, kernel_size)
            filter_var = torch.cat([filter_var] * groups)
            init.dirac_(filter_var, groups)
            output_var = F.conv2d(input_var, filter_var)
            input_tensor, output_tensor = (
                input_var.data,
                output_var.data,
            )  # Variables do not support nonzero
            for g in range(groups):
                # Assert in_c outputs are preserved (per group)
                self.assertEqual(
                    input_tensor[:, :, 1:-1, 1:-1],
                    output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :, :],
                )
                # Assert extra outputs are 0
                assert (
                    torch.nonzero(
                        output_tensor[
                            :, eff_out_c * g + in_c : eff_out_c * (g + 1), :, :
                        ]
                    ).numel()
                    == 0
                )
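
            # The 1:-1 crops above follow from kernel_size=3 with no padding:
            # a centered dirac kernel reproduces the input, minus one trimmed
            # element on each spatial border.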

            # Test 3D
            input_var = torch.randn(batch, in_c, size, size, size)
            filter_var = torch.zeros(
                eff_out_c, in_c, kernel_size, kernel_size, kernel_size
            )
            filter_var = torch.cat([filter_var] * groups)
            init.dirac_(filter_var, groups)
            output_var = F.conv3d(input_var, filter_var)
            input_tensor, output_tensor = input_var.data, output_var.data
            for g in range(groups):
                # Assert in_c outputs are preserved (per group)
                self.assertEqual(
                    input_tensor[:, :, 1:-1, 1:-1, 1:-1],
                    output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :, :, :],
                )
                # Assert extra outputs are 0
                assert (
                    torch.nonzero(
                        output_tensor[
                            :, eff_out_c * g + in_c : eff_out_c * (g + 1), :, :, :
                        ]
                    ).numel()
                    == 0
                )

    def test_dirac_only_works_on_3_4_5d_inputs(self):
        for dims in [1, 2, 6]:
            with self.assertRaises(ValueError):
                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3)
                init.dirac_(tensor)

    def test_xavier_uniform_errors_on_inputs_smaller_than_2d(self):
        for dims in [0, 1]:
            tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
            with self.assertRaises(ValueError):
                init.xavier_uniform_(tensor)

    def test_xavier_normal_errors_on_inputs_smaller_than_2d(self):
        for dims in [0, 1]:
            tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
            with self.assertRaises(ValueError):
                init.xavier_normal_(tensor)
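
    # For the fan computations in the tests below: a conv-style weight of shape
    # (out_channels, in_channels, *kernel) has fan_in = in_channels * prod(kernel)
    # and fan_out = out_channels * prod(kernel); input_tensor[0, 0].numel() is
    # exactly prod(kernel).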

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @slowTest
    def test_xavier_uniform(self):
        for use_gain in [True, False]:
            for dims in [2, 4]:
                input_tensor = self._create_random_nd_tensor(
                    dims, size_min=20, size_max=25
                )
                gain = 1

                if use_gain:
                    gain = self._random_float(0.1, 2)
                    init.xavier_uniform_(input_tensor, gain=gain)
                else:
                    init.xavier_uniform_(input_tensor)

                fan_in = input_tensor.size(1)
                fan_out = input_tensor.size(0)
                if input_tensor.dim() > 2:
                    fan_in *= input_tensor[0, 0].numel()
                    fan_out *= input_tensor[0, 0].numel()

                expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out))
                bounds = expected_std * math.sqrt(3)
                assert self._is_uniform(input_tensor, -bounds, bounds)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_xavier_normal(self):
        for use_gain in [True, False]:
            for dims in [2, 4]:
                input_tensor = self._create_random_nd_tensor(
                    dims, size_min=20, size_max=25
                )
                gain = 1

                if use_gain:
                    gain = self._random_float(0.1, 2)
                    init.xavier_normal_(input_tensor, gain=gain)
                else:
                    init.xavier_normal_(input_tensor)

                fan_in = input_tensor.size(1)
                fan_out = input_tensor.size(0)
                if input_tensor.dim() > 2:
                    fan_in *= input_tensor[0, 0].numel()
                    fan_out *= input_tensor[0, 0].numel()

                expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out))
                assert self._is_normal(input_tensor, 0, expected_std)
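
    # Why bounds = expected_std * sqrt(3) above: a uniform distribution on
    # (-b, b) has variance b**2 / 3, so matching a target standard deviation
    # sigma requires b = sigma * sqrt(3).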

    def test_kaiming_uniform_errors_on_inputs_smaller_than_2d(self):
        for dims in [0, 1]:
            with self.assertRaises(ValueError):
                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
                init.kaiming_uniform_(tensor)

    def test_kaiming_normal_errors_on_inputs_smaller_than_2d(self):
        for dims in [0, 1]:
            with self.assertRaises(ValueError):
                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
                init.kaiming_normal_(tensor)

    def test_kaiming_uniform_warning_on_0element_tensor(self):
        tensor = torch.empty(0, 1)
        with self.assertWarnsRegex(
            UserWarning, "Initializing zero-element tensors is a no-op"
        ):
            _ = init.kaiming_uniform_(tensor)

    def test_kaiming_normal_warning_on_0element_tensor(self):
        tensor = torch.empty(0, 1)
        with self.assertWarnsRegex(
            UserWarning, "Initializing zero-element tensors is a no-op"
        ):
            _ = init.kaiming_normal_(tensor)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_kaiming_uniform(self):
        for use_a in [True, False]:
            for dims in [2, 4]:
                for mode in ["fan_in", "fan_out"]:
                    input_tensor = self._create_random_nd_tensor(
                        dims, size_min=20, size_max=25
                    )
                    if use_a:
                        a = self._random_float(0.1, 2)
                        init.kaiming_uniform_(input_tensor, a=a, mode=mode)
                    else:
                        a = 0
                        init.kaiming_uniform_(input_tensor, mode=mode)

                    fan_in = input_tensor.size(1)
                    fan_out = input_tensor.size(0)
                    if input_tensor.dim() > 2:
                        fan_in *= input_tensor[0, 0].numel()
                        fan_out *= input_tensor[0, 0].numel()

                    if mode == "fan_in":
                        n = fan_in
                    else:
                        n = fan_out

                    expected_std = math.sqrt(2.0 / ((1 + a**2) * n))
                    bounds = expected_std * math.sqrt(3.0)
                    assert self._is_uniform(input_tensor, -bounds, bounds)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_kaiming_normal(self):
        for use_a in [True, False]:
            for dims in [2, 4]:
                for mode in ["fan_in", "fan_out"]:
                    input_tensor = self._create_random_nd_tensor(
                        dims, size_min=20, size_max=25
                    )
                    if use_a:
                        a = self._random_float(0.1, 2)
                        init.kaiming_normal_(input_tensor, a=a, mode=mode)
                    else:
                        a = 0
                        init.kaiming_normal_(input_tensor, mode=mode)

                    fan_in = input_tensor.size(1)
                    fan_out = input_tensor.size(0)
                    if input_tensor.dim() > 2:
                        fan_in *= input_tensor[0, 0].numel()
                        fan_out *= input_tensor[0, 0].numel()

                    if mode == "fan_in":
                        n = fan_in
                    else:
                        n = fan_out

                    expected_std = math.sqrt(2.0 / ((1 + a**2) * n))
                    assert self._is_normal(input_tensor, 0, expected_std)
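
    # Worked example for the Kaiming std above: with negative slope a = 0 and
    # fan n, expected_std = sqrt(2 / n); e.g. n = 512 gives sqrt(2 / 512) = 0.0625.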

    def test_sparse_only_works_on_2d_inputs(self):
        for dims in [1, 3]:
            with self.assertRaises(ValueError):
                sparsity = self._random_float(0.1, 0.9)
                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3)
                init.sparse_(tensor, sparsity)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_sparse_default_std(self):
        for use_random_std in [True, False]:
            input_tensor = self._create_random_nd_tensor(2, size_min=30, size_max=35)
            rows, cols = input_tensor.size(0), input_tensor.size(1)
            sparsity = self._random_float(0.1, 0.2)

            std = 0.01  # default std
            if use_random_std:
                std = self._random_float(0.01, 0.2)
                init.sparse_(input_tensor, sparsity=sparsity, std=std)
            else:
                init.sparse_(input_tensor, sparsity=sparsity)

            for col_idx in range(input_tensor.size(1)):
                column = input_tensor[:, col_idx]
                assert column[column == 0].nelement() >= math.ceil(sparsity * rows)

            assert self._is_normal(input_tensor[input_tensor != 0], 0, std)

    @skipIfNoLapack
    def test_orthogonal(self):
        for use_gain in [True, False]:
            for tensor_size in [[3, 4], [4, 3], [20, 2, 3, 4], [2, 3, 4, 5]]:
                input_tensor = torch.zeros(tensor_size)
                gain = 1.0

                if use_gain:
                    gain = self._random_float(0.1, 2)
                    init.orthogonal_(input_tensor, gain=gain)
                else:
                    init.orthogonal_(input_tensor)

                rows, cols = tensor_size[0], reduce(mul, tensor_size[1:])
                flattened_tensor = input_tensor.view(rows, cols)
                if rows > cols:
                    self.assertEqual(
                        torch.mm(flattened_tensor.t(), flattened_tensor),
                        torch.eye(cols) * gain**2,
                        atol=1e-6,
                        rtol=0,
                    )
                else:
                    self.assertEqual(
                        torch.mm(flattened_tensor, flattened_tensor.t()),
                        torch.eye(rows) * gain**2,
                        atol=1e-6,
                        rtol=0,
                    )
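
    # The orthogonality checks above use semi-orthogonality: with rows > cols
    # the columns are orthonormal up to the gain, so Q.T @ Q == gain**2 * I;
    # otherwise the rows are, so Q @ Q.T == gain**2 * I.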

    def test_deprecation(self):
        x = torch.randn(3, 3)

        def fn():
            init.normal(x)

        with self.assertWarnsRegex(
            FutureWarning,
            "deprecated",
            msg="methods not suffixed with underscore should be deprecated",
        ):
            fn()


if __name__ == "__main__":
    run_tests()