1# Owner(s): ["module: nn"] 2import math 3import random 4import string 5import unittest 6from functools import reduce 7from operator import mul 8 9import torch 10import torch.nn.functional as F 11import torch.nn.init as init 12from torch.testing._internal.common_utils import ( 13 run_tests, 14 skipIfNoLapack, 15 skipIfTorchDynamo, 16 slowTest, 17 TEST_SCIPY, 18 TestCase, 19) 20 21 22if TEST_SCIPY: 23 from scipy import stats 24 25 26class TestNNInit(TestCase): 27 def setUp(self): 28 super().setUp() 29 random.seed(123) 30 31 def _is_normal(self, tensor, mean, std): 32 samples = tensor.view(-1).tolist() 33 p_value = stats.kstest(samples, "norm", args=(mean, std))[1] 34 return p_value > 0.0001 35 36 def _is_trunc_normal(self, tensor, mean, std, a, b): 37 # scipy's trunc norm is suited for data drawn from N(0, 1), 38 # so we need to transform our data to test it using scipy. 39 z_samples = (tensor.view(-1) - mean) / std 40 z_samples = z_samples.tolist() 41 a0 = (a - mean) / std 42 b0 = (b - mean) / std 43 p_value = stats.kstest(z_samples, "truncnorm", args=(a0, b0))[1] 44 return p_value > 0.0001 45 46 def _is_uniform(self, tensor, a, b): 47 samples = tensor.view(-1).tolist() 48 p_value = stats.kstest(samples, "uniform", args=(a, (b - a)))[1] 49 return p_value > 0.0001 50 51 def _create_random_nd_tensor(self, dims, size_min, size_max): 52 size = [random.randint(size_min, size_max) for _ in range(dims)] 53 tensor = torch.zeros(size) 54 return tensor 55 56 def _random_float(self, a, b): 57 return (b - a) * random.random() + a 58 59 def test_calculate_gain_linear(self): 60 for fn in [ 61 "linear", 62 "conv1d", 63 "conv2d", 64 "conv3d", 65 "conv_transpose2d", 66 "conv_transpose2d", 67 "conv_transpose3d", 68 ]: 69 gain = init.calculate_gain(fn) 70 self.assertEqual(gain, 1) 71 72 def test_calculate_gain_nonlinear(self): 73 for fn in ["sigmoid", "tanh", "relu", "leaky_relu"]: 74 gain = init.calculate_gain(fn) 75 if fn == "sigmoid": 76 self.assertEqual(gain, 1) 77 elif fn == "tanh": # 5 / 3 78 self.assertEqual(gain, 1.6666666666666667) 79 elif fn == "relu": # sqrt(2) 80 self.assertEqual(gain, 1.4142135623730951) 81 elif fn == "leaky_relu": # sqrt(2 / 1 + slope^2)) 82 self.assertEqual(gain, 1.4141428569978354) 83 elif fn == "selu": 84 self.assertEqual(gain, 0.75) 85 86 def test_calculate_gain_leaky_relu(self): 87 for param in [None, 0, 0.01, 10]: 88 gain = init.calculate_gain("leaky_relu", param) 89 if param is None: # Default slope is 0.01 90 self.assertEqual(gain, 1.4141428569978354) 91 elif param == 0: # No slope = same gain as normal ReLU 92 self.assertEqual(gain, 1.4142135623730951) 93 elif param == 0.01: 94 self.assertEqual(gain, 1.4141428569978354) 95 elif param == 10: 96 self.assertEqual(gain, 0.14071950894605836) 97 98 def test_calculate_gain_leaky_relu_only_accepts_numbers(self): 99 for param in [True, [1], {"a": "b"}]: 100 with self.assertRaises(ValueError): 101 init.calculate_gain("leaky_relu", param) 102 103 def test_calculate_gain_only_accepts_valid_nonlinearities(self): 104 for n in [2, 5, 25]: 105 # Generate random strings of lengths that definitely aren't supported 106 random_string = "".join( 107 [random.choice(string.ascii_lowercase) for i in range(n)] 108 ) 109 with self.assertRaises(ValueError): 110 init.calculate_gain(random_string) 111 112 @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") 113 @skipIfTorchDynamo("scipy.kstest is failing under dynamo") 114 def test_uniform(self): 115 for dims in [1, 2, 4]: 116 input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) 117 a = self._random_float(-3, 3) 118 b = a + self._random_float(1, 5) 119 init.uniform_(input_tensor, a=a, b=b) 120 assert self._is_uniform(input_tensor, a, b) 121 122 @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") 123 @skipIfTorchDynamo("scipy.kstest is failing under dynamo") 124 def test_normal(self): 125 for dims in [1, 2, 4]: 126 input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) 127 mean = self._random_float(-3, 3) 128 std = self._random_float(1, 5) 129 init.normal_(input_tensor, mean=mean, std=std) 130 131 assert self._is_normal(input_tensor, mean, std) 132 133 @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") 134 @skipIfTorchDynamo("scipy.kstest is failing under dynamo") 135 def test_trunc_normal(self): 136 for dims in [1, 2, 4]: 137 input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) 138 mean = self._random_float(-3, 3) 139 std = self._random_float(0.01, 1) 140 a = self._random_float(mean - 2 * std, mean) 141 b = self._random_float(mean, mean + 2 * std) 142 init.trunc_normal_(input_tensor, mean=mean, std=std, a=a, b=b) 143 144 assert self._is_trunc_normal(input_tensor, mean, std, a, b) 145 146 @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") 147 @skipIfTorchDynamo("scipy.kstest is failing under dynamo") 148 def test_trunc_normal_generator(self): 149 gen = torch.Generator() 150 gen.manual_seed(42) 151 input_tensor = torch.empty(5) 152 init.trunc_normal_(input_tensor, generator=gen) 153 154 ref = torch.empty(5) 155 torch.manual_seed(42) 156 init.trunc_normal_(ref) 157 158 self.assertEqual(input_tensor, ref) 159 assert self._is_trunc_normal(input_tensor, mean=0, std=1, a=0, b=1) 160 161 def test_constant(self): 162 for dims in [1, 2, 4]: 163 input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5) 164 val = self._random_float(1, 10) 165 init.constant_(input_tensor, val) 166 167 self.assertEqual(input_tensor, input_tensor.clone().fill_(val)) 168 169 def test_ones_and_zeros(self): 170 for init_fn_, val in zip([init.ones_, init.zeros_], [1, 0]): 171 for dims in [1, 2, 4]: 172 input_tensor = self._create_random_nd_tensor( 173 dims, size_min=1, size_max=5 174 ) 175 init_fn_(input_tensor) 176 177 self.assertEqual(input_tensor, input_tensor.clone().fill_(val)) 178 179 def test_eye(self): 180 input_tensor = self._create_random_nd_tensor(2, size_min=1, size_max=5) 181 init.eye_(input_tensor) 182 183 # Check every single element 184 for i in range(input_tensor.size(0)): 185 for j in range(input_tensor.size(1)): 186 if i == j: 187 assert input_tensor[i][j] == 1 188 else: 189 assert input_tensor[i][j] == 0 190 191 def test_eye_only_works_on_2d_inputs(self): 192 for dims in [1, 3]: 193 with self.assertRaises(ValueError): 194 tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) 195 init.eye_(tensor) 196 197 def test_dirac_properties(self): 198 for dims in [3, 4, 5]: 199 for groups in [1, 2, 3]: 200 # prepare random tensor with random sizes, but fits groups 201 a, c, d, e = (random.randint(1, 5) for _ in range(4)) 202 b = random.randint( 203 1, 5 * groups 204 ) # same range as a*groups but all range allowed 205 # make sure first dim divides by groups 206 input_tensor = torch.randn((a * groups, b, c, d, e)[:dims]) 207 208 init.dirac_(input_tensor, groups) 209 210 c_out, c_in = input_tensor.size(0) // groups, input_tensor.size(1) 211 min_d = min(c_out, c_in) 212 # Check number of nonzeros is equivalent to smallest dim (for each group) 213 assert torch.nonzero(input_tensor).size(0) == min_d * groups 214 # Check sum of values (can have precision issues, hence assertEqual) is also equivalent 215 self.assertEqual(input_tensor.sum(), min_d * groups) 216 217 def test_dirac_identity(self): 218 for groups in [1, 3]: 219 batch, in_c, out_c, size, kernel_size = ( 220 8, 221 3, 222 9, 223 5, 224 3, 225 ) # in_c, out_c must divide by groups 226 eff_out_c = out_c // groups 227 228 # Test 1D 229 input_var = torch.randn(batch, in_c, size) 230 filter_var = torch.zeros(eff_out_c, in_c, kernel_size) 231 filter_var = torch.cat([filter_var] * groups) 232 init.dirac_(filter_var, groups) 233 output_var = F.conv1d(input_var, filter_var) 234 input_tensor, output_tensor = ( 235 input_var.data, 236 output_var.data, 237 ) # Variables do not support nonzero 238 for g in range(groups): 239 # Assert in_c outputs are preserved (per each group) 240 self.assertEqual( 241 input_tensor[:, :, 1:-1], 242 output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :], 243 ) 244 # Assert extra outputs are 0 245 assert ( 246 torch.nonzero( 247 output_tensor[:, eff_out_c * g + in_c : eff_out_c * (g + 1), :] 248 ).numel() 249 == 0 250 ) 251 252 # Test 2D 253 input_var = torch.randn(batch, in_c, size, size) 254 filter_var = torch.zeros(eff_out_c, in_c, kernel_size, kernel_size) 255 filter_var = torch.cat([filter_var] * groups) 256 init.dirac_(filter_var, groups) 257 output_var = F.conv2d(input_var, filter_var) 258 input_tensor, output_tensor = ( 259 input_var.data, 260 output_var.data, 261 ) # Variables do not support nonzero 262 for g in range(groups): 263 # Assert in_c outputs are preserved (per each group) 264 self.assertEqual( 265 input_tensor[:, :, 1:-1, 1:-1], 266 output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :, :], 267 ) 268 # Assert extra outputs are 0 269 assert ( 270 torch.nonzero( 271 output_tensor[ 272 :, eff_out_c * g + in_c : eff_out_c * (g + 1), :, : 273 ] 274 ).numel() 275 == 0 276 ) 277 278 # Test 3D 279 input_var = torch.randn(batch, in_c, size, size, size) 280 filter_var = torch.zeros( 281 eff_out_c, in_c, kernel_size, kernel_size, kernel_size 282 ) 283 filter_var = torch.cat([filter_var] * groups) 284 init.dirac_(filter_var, groups) 285 output_var = F.conv3d(input_var, filter_var) 286 input_tensor, output_tensor = input_var.data, output_var.data 287 for g in range(groups): 288 # Assert in_c outputs are preserved (per each group) 289 self.assertEqual( 290 input_tensor[:, :, 1:-1, 1:-1, 1:-1], 291 output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :, :, :], 292 ) 293 # Assert extra outputs are 0 294 assert ( 295 torch.nonzero( 296 output_tensor[ 297 :, eff_out_c * g + in_c : eff_out_c * (g + 1), :, :, : 298 ] 299 ).numel() 300 == 0 301 ) 302 303 def test_dirac_only_works_on_3_4_5d_inputs(self): 304 for dims in [1, 2, 6]: 305 with self.assertRaises(ValueError): 306 tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) 307 init.dirac_(tensor) 308 309 def test_xavier_uniform_errors_on_inputs_smaller_than_2d(self): 310 for dims in [0, 1]: 311 tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) 312 with self.assertRaises(ValueError): 313 init.xavier_uniform_(tensor) 314 315 def test_xavier_normal_errors_on_inputs_smaller_than_2d(self): 316 for dims in [0, 1]: 317 tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) 318 with self.assertRaises(ValueError): 319 init.xavier_normal_(tensor) 320 321 @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") 322 @slowTest 323 def test_xavier_uniform(self): 324 for use_gain in [True, False]: 325 for dims in [2, 4]: 326 input_tensor = self._create_random_nd_tensor( 327 dims, size_min=20, size_max=25 328 ) 329 gain = 1 330 331 if use_gain: 332 gain = self._random_float(0.1, 2) 333 init.xavier_uniform_(input_tensor, gain=gain) 334 else: 335 init.xavier_uniform_(input_tensor) 336 337 fan_in = input_tensor.size(1) 338 fan_out = input_tensor.size(0) 339 if input_tensor.dim() > 2: 340 fan_in *= input_tensor[0, 0].numel() 341 fan_out *= input_tensor[0, 0].numel() 342 343 expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out)) 344 bounds = expected_std * math.sqrt(3) 345 assert self._is_uniform(input_tensor, -bounds, bounds) 346 347 @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") 348 @skipIfTorchDynamo("scipy.kstest is failing under dynamo") 349 def test_xavier_normal(self): 350 for use_gain in [True, False]: 351 for dims in [2, 4]: 352 input_tensor = self._create_random_nd_tensor( 353 dims, size_min=20, size_max=25 354 ) 355 gain = 1 356 357 if use_gain: 358 gain = self._random_float(0.1, 2) 359 init.xavier_normal_(input_tensor, gain=gain) 360 else: 361 init.xavier_normal_(input_tensor) 362 363 fan_in = input_tensor.size(1) 364 fan_out = input_tensor.size(0) 365 if input_tensor.dim() > 2: 366 fan_in *= input_tensor[0, 0].numel() 367 fan_out *= input_tensor[0, 0].numel() 368 369 expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out)) 370 assert self._is_normal(input_tensor, 0, expected_std) 371 372 def test_kaiming_uniform_errors_on_inputs_smaller_than_2d(self): 373 for dims in [0, 1]: 374 with self.assertRaises(ValueError): 375 tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) 376 init.kaiming_uniform_(tensor) 377 378 def test_kaiming_normal_errors_on_inputs_smaller_than_2d(self): 379 for dims in [0, 1]: 380 with self.assertRaises(ValueError): 381 tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) 382 init.kaiming_normal_(tensor) 383 384 def test_kaiming_uniform_warning_on_0element_tensor(self): 385 tensor = torch.empty(0, 1) 386 with self.assertWarnsRegex( 387 UserWarning, "Initializing zero-element tensors is a no-op" 388 ): 389 _ = init.kaiming_uniform_(tensor) 390 391 def test_kaiming_normal_warning_on_0element_tensor(self): 392 tensor = torch.empty(0, 1) 393 with self.assertWarnsRegex( 394 UserWarning, "Initializing zero-element tensors is a no-op" 395 ): 396 _ = init.kaiming_normal_(tensor) 397 398 @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") 399 @skipIfTorchDynamo("scipy.kstest is failing under dynamo") 400 def test_kaiming_uniform(self): 401 for use_a in [True, False]: 402 for dims in [2, 4]: 403 for mode in ["fan_in", "fan_out"]: 404 input_tensor = self._create_random_nd_tensor( 405 dims, size_min=20, size_max=25 406 ) 407 if use_a: 408 a = self._random_float(0.1, 2) 409 init.kaiming_uniform_(input_tensor, a=a, mode=mode) 410 else: 411 a = 0 412 init.kaiming_uniform_(input_tensor, mode=mode) 413 414 fan_in = input_tensor.size(1) 415 fan_out = input_tensor.size(0) 416 if input_tensor.dim() > 2: 417 fan_in *= input_tensor[0, 0].numel() 418 fan_out *= input_tensor[0, 0].numel() 419 420 if mode == "fan_in": 421 n = fan_in 422 else: 423 n = fan_out 424 425 expected_std = math.sqrt(2.0 / ((1 + a**2) * n)) 426 bounds = expected_std * math.sqrt(3.0) 427 assert self._is_uniform(input_tensor, -bounds, bounds) 428 429 @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") 430 @skipIfTorchDynamo("scipy.kstest is failing under dynamo") 431 def test_kaiming_normal(self): 432 for use_a in [True, False]: 433 for dims in [2, 4]: 434 for mode in ["fan_in", "fan_out"]: 435 input_tensor = self._create_random_nd_tensor( 436 dims, size_min=20, size_max=25 437 ) 438 if use_a: 439 a = self._random_float(0.1, 2) 440 init.kaiming_normal_(input_tensor, a=a, mode=mode) 441 else: 442 a = 0 443 init.kaiming_normal_(input_tensor, mode=mode) 444 445 fan_in = input_tensor.size(1) 446 fan_out = input_tensor.size(0) 447 if input_tensor.dim() > 2: 448 fan_in *= input_tensor[0, 0].numel() 449 fan_out *= input_tensor[0, 0].numel() 450 451 if mode == "fan_in": 452 n = fan_in 453 else: 454 n = fan_out 455 456 expected_std = math.sqrt(2.0 / ((1 + a**2) * n)) 457 assert self._is_normal(input_tensor, 0, expected_std) 458 459 def test_sparse_only_works_on_2d_inputs(self): 460 for dims in [1, 3]: 461 with self.assertRaises(ValueError): 462 sparsity = self._random_float(0.1, 0.9) 463 tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) 464 init.sparse_(tensor, sparsity) 465 466 @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") 467 @skipIfTorchDynamo("scipy.kstest is failing under dynamo") 468 def test_sparse_default_std(self): 469 for use_random_std in [True, False]: 470 input_tensor = self._create_random_nd_tensor(2, size_min=30, size_max=35) 471 rows, cols = input_tensor.size(0), input_tensor.size(1) 472 sparsity = self._random_float(0.1, 0.2) 473 474 std = 0.01 # default std 475 if use_random_std: 476 std = self._random_float(0.01, 0.2) 477 init.sparse_(input_tensor, sparsity=sparsity, std=std) 478 else: 479 init.sparse_(input_tensor, sparsity=sparsity) 480 481 for col_idx in range(input_tensor.size(1)): 482 column = input_tensor[:, col_idx] 483 assert column[column == 0].nelement() >= math.ceil(sparsity * rows) 484 485 assert self._is_normal(input_tensor[input_tensor != 0], 0, std) 486 487 @skipIfNoLapack 488 def test_orthogonal(self): 489 for use_gain in [True, False]: 490 for tensor_size in [[3, 4], [4, 3], [20, 2, 3, 4], [2, 3, 4, 5]]: 491 input_tensor = torch.zeros(tensor_size) 492 gain = 1.0 493 494 if use_gain: 495 gain = self._random_float(0.1, 2) 496 init.orthogonal_(input_tensor, gain=gain) 497 else: 498 init.orthogonal_(input_tensor) 499 500 rows, cols = tensor_size[0], reduce(mul, tensor_size[1:]) 501 flattened_tensor = input_tensor.view(rows, cols) 502 if rows > cols: 503 self.assertEqual( 504 torch.mm(flattened_tensor.t(), flattened_tensor), 505 torch.eye(cols) * gain**2, 506 atol=1e-6, 507 rtol=0, 508 ) 509 else: 510 self.assertEqual( 511 torch.mm(flattened_tensor, flattened_tensor.t()), 512 torch.eye(rows) * gain**2, 513 atol=1e-6, 514 rtol=0, 515 ) 516 517 def test_deprecation(self): 518 x = torch.randn(3, 3) 519 520 def fn(): 521 init.normal(x) 522 523 with self.assertWarnsRegex( 524 FutureWarning, 525 "deprecated", 526 msg="methods not suffixed with underscore should be deprecated", 527 ): 528 fn() 529 530 531if __name__ == "__main__": 532 run_tests() 533