# Owner(s): ["module: nn"]
import math
import random
import string
import unittest
from functools import reduce
from operator import mul

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfNoLapack,
    skipIfTorchDynamo,
    slowTest,
    TEST_SCIPY,
    TestCase,
)


if TEST_SCIPY:
    from scipy import stats


class TestNNInit(TestCase):
    def setUp(self):
        super().setUp()
        random.seed(123)

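    # The _is_* helpers below run a Kolmogorov-Smirnov goodness-of-fit test via
    # scipy.stats.kstest; a p-value above the (deliberately loose) 1e-4
    # threshold means the samples are consistent with the named distribution.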
    def _is_normal(self, tensor, mean, std):
        samples = tensor.view(-1).tolist()
        p_value = stats.kstest(samples, "norm", args=(mean, std))[1]
        return p_value > 0.0001

    def _is_trunc_normal(self, tensor, mean, std, a, b):
        # scipy's truncnorm is parameterized for data drawn from N(0, 1),
        # so we standardize our samples before testing against it.
        z_samples = (tensor.view(-1) - mean) / std
        z_samples = z_samples.tolist()
        a0 = (a - mean) / std
        b0 = (b - mean) / std
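        # e.g. mean=1, std=2, [a, b] = [0, 3] standardizes to [a0, b0] = [-0.5, 1.0]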
        p_value = stats.kstest(z_samples, "truncnorm", args=(a0, b0))[1]
        return p_value > 0.0001

    def _is_uniform(self, tensor, a, b):
        samples = tensor.view(-1).tolist()
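        # scipy parameterizes "uniform" as U(loc, loc + scale), hence (a, b - a)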
        p_value = stats.kstest(samples, "uniform", args=(a, (b - a)))[1]
        return p_value > 0.0001

    def _create_random_nd_tensor(self, dims, size_min, size_max):
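        # Only the sizes are random; the values start as zeros, to be filled in
        # by the init function under test.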
        size = [random.randint(size_min, size_max) for _ in range(dims)]
        tensor = torch.zeros(size)
        return tensor

    def _random_float(self, a, b):
        return (b - a) * random.random() + a

    def test_calculate_gain_linear(self):
        for fn in [
            "linear",
            "conv1d",
            "conv2d",
            "conv3d",
            "conv_transpose1d",
            "conv_transpose2d",
            "conv_transpose3d",
        ]:
            gain = init.calculate_gain(fn)
            self.assertEqual(gain, 1)

    def test_calculate_gain_nonlinear(self):
        for fn in ["sigmoid", "tanh", "relu", "leaky_relu", "selu"]:
            gain = init.calculate_gain(fn)
            if fn == "sigmoid":
                self.assertEqual(gain, 1)
            elif fn == "tanh":  # 5 / 3
                self.assertEqual(gain, 1.6666666666666667)
            elif fn == "relu":  # sqrt(2)
                self.assertEqual(gain, 1.4142135623730951)
            elif fn == "leaky_relu":  # sqrt(2 / (1 + slope^2))
                self.assertEqual(gain, 1.4141428569978354)
            elif fn == "selu":  # 3 / 4
                self.assertEqual(gain, 0.75)

    def test_calculate_gain_leaky_relu(self):
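        # calculate_gain("leaky_relu", a) follows sqrt(2 / (1 + a**2)); e.g.
        # a=0.01 gives sqrt(2 / 1.0001) ~ 1.41414 and a=10 gives
        # sqrt(2 / 101) ~ 0.14072, matching the constants asserted below.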
        for param in [None, 0, 0.01, 10]:
            gain = init.calculate_gain("leaky_relu", param)
            if param is None:  # Default slope is 0.01
                self.assertEqual(gain, 1.4141428569978354)
            elif param == 0:  # No slope = same gain as normal ReLU
                self.assertEqual(gain, 1.4142135623730951)
            elif param == 0.01:
                self.assertEqual(gain, 1.4141428569978354)
            elif param == 10:
                self.assertEqual(gain, 0.14071950894605836)

    def test_calculate_gain_leaky_relu_only_accepts_numbers(self):
        for param in [True, [1], {"a": "b"}]:
            with self.assertRaises(ValueError):
                init.calculate_gain("leaky_relu", param)

    def test_calculate_gain_only_accepts_valid_nonlinearities(self):
        for n in [2, 5, 25]:
            # Random strings of these lengths cannot collide with any
            # supported nonlinearity name
            random_string = "".join(
                [random.choice(string.ascii_lowercase) for _ in range(n)]
            )
            with self.assertRaises(ValueError):
                init.calculate_gain(random_string)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_uniform(self):
        for dims in [1, 2, 4]:
            input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50)
            a = self._random_float(-3, 3)
            b = a + self._random_float(1, 5)
            init.uniform_(input_tensor, a=a, b=b)
            assert self._is_uniform(input_tensor, a, b)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_normal(self):
        for dims in [1, 2, 4]:
            input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50)
            mean = self._random_float(-3, 3)
            std = self._random_float(1, 5)
            init.normal_(input_tensor, mean=mean, std=std)

            assert self._is_normal(input_tensor, mean, std)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_trunc_normal(self):
        for dims in [1, 2, 4]:
            input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50)
            mean = self._random_float(-3, 3)
            std = self._random_float(0.01, 1)
            a = self._random_float(mean - 2 * std, mean)
            b = self._random_float(mean, mean + 2 * std)
            init.trunc_normal_(input_tensor, mean=mean, std=std, a=a, b=b)

            assert self._is_trunc_normal(input_tensor, mean, std, a, b)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_trunc_normal_generator(self):
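        # Passing an explicit, seeded torch.Generator should reproduce the same
        # draw as seeding the global RNG with the same seed.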
        gen = torch.Generator()
        gen.manual_seed(42)
        input_tensor = torch.empty(5)
        init.trunc_normal_(input_tensor, generator=gen)

        ref = torch.empty(5)
        torch.manual_seed(42)
        init.trunc_normal_(ref)

        self.assertEqual(input_tensor, ref)
        assert self._is_trunc_normal(input_tensor, mean=0, std=1, a=-2, b=2)

    def test_constant(self):
        for dims in [1, 2, 4]:
            input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5)
            val = self._random_float(1, 10)
            init.constant_(input_tensor, val)

            self.assertEqual(input_tensor, input_tensor.clone().fill_(val))

    def test_ones_and_zeros(self):
        for init_fn_, val in zip([init.ones_, init.zeros_], [1, 0]):
            for dims in [1, 2, 4]:
                input_tensor = self._create_random_nd_tensor(
                    dims, size_min=1, size_max=5
                )
                init_fn_(input_tensor)

                self.assertEqual(input_tensor, input_tensor.clone().fill_(val))

    def test_eye(self):
        input_tensor = self._create_random_nd_tensor(2, size_min=1, size_max=5)
        init.eye_(input_tensor)

        # Check every single element
        for i in range(input_tensor.size(0)):
            for j in range(input_tensor.size(1)):
                if i == j:
                    assert input_tensor[i][j] == 1
                else:
                    assert input_tensor[i][j] == 0

    def test_eye_only_works_on_2d_inputs(self):
        for dims in [1, 3]:
            with self.assertRaises(ValueError):
                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3)
                init.eye_(tensor)

    def test_dirac_properties(self):
        for dims in [3, 4, 5]:
            for groups in [1, 2, 3]:
                # prepare a tensor with random sizes that fit the groups
                a, c, d, e = (random.randint(1, 5) for _ in range(4))
                b = random.randint(
                    1, 5 * groups
                )  # same overall range as a * groups, but any value is allowed
                # only the first dim must be divisible by groups
                input_tensor = torch.randn((a * groups, b, c, d, e)[:dims])

                init.dirac_(input_tensor, groups)

                c_out, c_in = input_tensor.size(0) // groups, input_tensor.size(1)
                min_d = min(c_out, c_in)
                # Check the number of nonzeros equals the smallest dim (per group)
                assert torch.nonzero(input_tensor).size(0) == min_d * groups
                # Check the sum of values matches too (assertEqual tolerates float precision)
                self.assertEqual(input_tensor.sum(), min_d * groups)
    def test_dirac_identity(self):
        for groups in [1, 3]:
            batch, in_c, out_c, size, kernel_size = (
                8,
                3,
                9,
                5,
                3,
            )  # in_c and out_c must be divisible by groups
            eff_out_c = out_c // groups

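            # dirac_ makes the convolution act as the identity on the first
            # in_c output channels of every group; with kernel_size=3 and no
            # padding each spatial dim shrinks by 2, hence the 1:-1 slices below.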
            # Test 1D
            input_var = torch.randn(batch, in_c, size)
            filter_var = torch.zeros(eff_out_c, in_c, kernel_size)
            filter_var = torch.cat([filter_var] * groups)
            init.dirac_(filter_var, groups)
            output_var = F.conv1d(input_var, filter_var)
            input_tensor, output_tensor = input_var, output_var
            for g in range(groups):
                # Assert in_c outputs are preserved (for each group)
                self.assertEqual(
                    input_tensor[:, :, 1:-1],
                    output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :],
                )
                # Assert extra outputs are 0
                assert (
                    torch.nonzero(
                        output_tensor[:, eff_out_c * g + in_c : eff_out_c * (g + 1), :]
                    ).numel()
                    == 0
                )

            # Test 2D
            input_var = torch.randn(batch, in_c, size, size)
            filter_var = torch.zeros(eff_out_c, in_c, kernel_size, kernel_size)
            filter_var = torch.cat([filter_var] * groups)
            init.dirac_(filter_var, groups)
            output_var = F.conv2d(input_var, filter_var)
            input_tensor, output_tensor = input_var, output_var
            for g in range(groups):
                # Assert in_c outputs are preserved (for each group)
                self.assertEqual(
                    input_tensor[:, :, 1:-1, 1:-1],
                    output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :, :],
                )
                # Assert extra outputs are 0
                assert (
                    torch.nonzero(
                        output_tensor[
                            :, eff_out_c * g + in_c : eff_out_c * (g + 1), :, :
                        ]
                    ).numel()
                    == 0
                )

            # Test 3D
            input_var = torch.randn(batch, in_c, size, size, size)
            filter_var = torch.zeros(
                eff_out_c, in_c, kernel_size, kernel_size, kernel_size
            )
            filter_var = torch.cat([filter_var] * groups)
            init.dirac_(filter_var, groups)
            output_var = F.conv3d(input_var, filter_var)
            input_tensor, output_tensor = input_var, output_var
            for g in range(groups):
                # Assert in_c outputs are preserved (for each group)
                self.assertEqual(
                    input_tensor[:, :, 1:-1, 1:-1, 1:-1],
                    output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :, :, :],
                )
                # Assert extra outputs are 0
                assert (
                    torch.nonzero(
                        output_tensor[
                            :, eff_out_c * g + in_c : eff_out_c * (g + 1), :, :, :
                        ]
                    ).numel()
                    == 0
                )

    def test_dirac_only_works_on_3_4_5d_inputs(self):
        for dims in [1, 2, 6]:
            with self.assertRaises(ValueError):
                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3)
                init.dirac_(tensor)

    def test_xavier_uniform_errors_on_inputs_smaller_than_2d(self):
        for dims in [0, 1]:
            tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
            with self.assertRaises(ValueError):
                init.xavier_uniform_(tensor)

    def test_xavier_normal_errors_on_inputs_smaller_than_2d(self):
        for dims in [0, 1]:
            tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
            with self.assertRaises(ValueError):
                init.xavier_normal_(tensor)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @slowTest
    def test_xavier_uniform(self):
        for use_gain in [True, False]:
            for dims in [2, 4]:
                input_tensor = self._create_random_nd_tensor(
                    dims, size_min=20, size_max=25
                )
                gain = 1

                if use_gain:
                    gain = self._random_float(0.1, 2)
                    init.xavier_uniform_(input_tensor, gain=gain)
                else:
                    init.xavier_uniform_(input_tensor)

                fan_in = input_tensor.size(1)
                fan_out = input_tensor.size(0)
                if input_tensor.dim() > 2:
                    fan_in *= input_tensor[0, 0].numel()
                    fan_out *= input_tensor[0, 0].numel()

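                # Glorot/Xavier: Var(W) = gain^2 * 2 / (fan_in + fan_out); for
                # U(-bound, bound), Var = bound^2 / 3, so bound = sqrt(3) * std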
                expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out))
                bounds = expected_std * math.sqrt(3)
                assert self._is_uniform(input_tensor, -bounds, bounds)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_xavier_normal(self):
        for use_gain in [True, False]:
            for dims in [2, 4]:
                input_tensor = self._create_random_nd_tensor(
                    dims, size_min=20, size_max=25
                )
                gain = 1

                if use_gain:
                    gain = self._random_float(0.1, 2)
                    init.xavier_normal_(input_tensor, gain=gain)
                else:
                    init.xavier_normal_(input_tensor)

                fan_in = input_tensor.size(1)
                fan_out = input_tensor.size(0)
                if input_tensor.dim() > 2:
                    fan_in *= input_tensor[0, 0].numel()
                    fan_out *= input_tensor[0, 0].numel()

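                # Same Glorot variance as above; the normal variant uses the
                # std directly rather than a uniform bound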
                expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out))
                assert self._is_normal(input_tensor, 0, expected_std)

    def test_kaiming_uniform_errors_on_inputs_smaller_than_2d(self):
        for dims in [0, 1]:
            with self.assertRaises(ValueError):
                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
                init.kaiming_uniform_(tensor)

    def test_kaiming_normal_errors_on_inputs_smaller_than_2d(self):
        for dims in [0, 1]:
            with self.assertRaises(ValueError):
                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
                init.kaiming_normal_(tensor)

    def test_kaiming_uniform_warning_on_0element_tensor(self):
        tensor = torch.empty(0, 1)
        with self.assertWarnsRegex(
            UserWarning, "Initializing zero-element tensors is a no-op"
        ):
            _ = init.kaiming_uniform_(tensor)

    def test_kaiming_normal_warning_on_0element_tensor(self):
        tensor = torch.empty(0, 1)
        with self.assertWarnsRegex(
            UserWarning, "Initializing zero-element tensors is a no-op"
        ):
            _ = init.kaiming_normal_(tensor)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_kaiming_uniform(self):
        for use_a in [True, False]:
            for dims in [2, 4]:
                for mode in ["fan_in", "fan_out"]:
                    input_tensor = self._create_random_nd_tensor(
                        dims, size_min=20, size_max=25
                    )
                    if use_a:
                        a = self._random_float(0.1, 2)
                        init.kaiming_uniform_(input_tensor, a=a, mode=mode)
                    else:
                        a = 0
                        init.kaiming_uniform_(input_tensor, mode=mode)

                    fan_in = input_tensor.size(1)
                    fan_out = input_tensor.size(0)
                    if input_tensor.dim() > 2:
                        fan_in *= input_tensor[0, 0].numel()
                        fan_out *= input_tensor[0, 0].numel()

                    if mode == "fan_in":
                        n = fan_in
                    else:
                        n = fan_out

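                    # He/Kaiming: gain = sqrt(2 / (1 + a**2)) for leaky_relu
                    # slope a, std = gain / sqrt(fan); the uniform bound
                    # sqrt(3) * std yields that same variance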
                    expected_std = math.sqrt(2.0 / ((1 + a**2) * n))
                    bounds = expected_std * math.sqrt(3.0)
                    assert self._is_uniform(input_tensor, -bounds, bounds)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_kaiming_normal(self):
        for use_a in [True, False]:
            for dims in [2, 4]:
                for mode in ["fan_in", "fan_out"]:
                    input_tensor = self._create_random_nd_tensor(
                        dims, size_min=20, size_max=25
                    )
                    if use_a:
                        a = self._random_float(0.1, 2)
                        init.kaiming_normal_(input_tensor, a=a, mode=mode)
                    else:
                        a = 0
                        init.kaiming_normal_(input_tensor, mode=mode)

                    fan_in = input_tensor.size(1)
                    fan_out = input_tensor.size(0)
                    if input_tensor.dim() > 2:
                        fan_in *= input_tensor[0, 0].numel()
                        fan_out *= input_tensor[0, 0].numel()

                    if mode == "fan_in":
                        n = fan_in
                    else:
                        n = fan_out

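                    # Same He std; the normal variant samples N(0, expected_std)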
                    expected_std = math.sqrt(2.0 / ((1 + a**2) * n))
                    assert self._is_normal(input_tensor, 0, expected_std)

    def test_sparse_only_works_on_2d_inputs(self):
        for dims in [1, 3]:
            with self.assertRaises(ValueError):
                sparsity = self._random_float(0.1, 0.9)
                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3)
                init.sparse_(tensor, sparsity)

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
    def test_sparse_default_std(self):
        for use_random_std in [True, False]:
            input_tensor = self._create_random_nd_tensor(2, size_min=30, size_max=35)
            rows, cols = input_tensor.size(0), input_tensor.size(1)
            sparsity = self._random_float(0.1, 0.2)

            std = 0.01  # default std
            if use_random_std:
                std = self._random_float(0.01, 0.2)
                init.sparse_(input_tensor, sparsity=sparsity, std=std)
            else:
                init.sparse_(input_tensor, sparsity=sparsity)

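            # sparse_ zeroes out ceil(sparsity * rows) entries in every column
            # and draws the remaining entries from N(0, std)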
            for col_idx in range(cols):
                column = input_tensor[:, col_idx]
                assert column[column == 0].nelement() >= math.ceil(sparsity * rows)

            assert self._is_normal(input_tensor[input_tensor != 0], 0, std)

    @skipIfNoLapack
    def test_orthogonal(self):
        for use_gain in [True, False]:
            for tensor_size in [[3, 4], [4, 3], [20, 2, 3, 4], [2, 3, 4, 5]]:
                input_tensor = torch.zeros(tensor_size)
                gain = 1.0

                if use_gain:
                    gain = self._random_float(0.1, 2)
                    init.orthogonal_(input_tensor, gain=gain)
                else:
                    init.orthogonal_(input_tensor)

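                # orthogonal_ yields a (semi-)orthogonal matrix scaled by gain,
                # so the Gram matrix along the smaller dimension is gain**2 * I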
                rows, cols = tensor_size[0], reduce(mul, tensor_size[1:])
                flattened_tensor = input_tensor.view(rows, cols)
                if rows > cols:
                    self.assertEqual(
                        torch.mm(flattened_tensor.t(), flattened_tensor),
                        torch.eye(cols) * gain**2,
                        atol=1e-6,
                        rtol=0,
                    )
                else:
                    self.assertEqual(
                        torch.mm(flattened_tensor, flattened_tensor.t()),
                        torch.eye(rows) * gain**2,
                        atol=1e-6,
                        rtol=0,
                    )

    def test_deprecation(self):
        x = torch.randn(3, 3)

        def fn():
            init.normal(x)

        with self.assertWarnsRegex(
            FutureWarning,
            "deprecated",
            msg="methods not suffixed with underscore should be deprecated",
        ):
            fn()


if __name__ == "__main__":
    run_tests()