xref: /aosp_15_r20/external/pytorch/test/nn/test_init.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: nn"]
2*da0073e9SAndroid Build Coastguard Workerimport math
3*da0073e9SAndroid Build Coastguard Workerimport random
4*da0073e9SAndroid Build Coastguard Workerimport string
5*da0073e9SAndroid Build Coastguard Workerimport unittest
6*da0073e9SAndroid Build Coastguard Workerfrom functools import reduce
7*da0073e9SAndroid Build Coastguard Workerfrom operator import mul
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
11*da0073e9SAndroid Build Coastguard Workerimport torch.nn.init as init
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
13*da0073e9SAndroid Build Coastguard Worker    run_tests,
14*da0073e9SAndroid Build Coastguard Worker    skipIfNoLapack,
15*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
16*da0073e9SAndroid Build Coastguard Worker    slowTest,
17*da0073e9SAndroid Build Coastguard Worker    TEST_SCIPY,
18*da0073e9SAndroid Build Coastguard Worker    TestCase,
19*da0073e9SAndroid Build Coastguard Worker)
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Workerif TEST_SCIPY:
23*da0073e9SAndroid Build Coastguard Worker    from scipy import stats
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Workerclass TestNNInit(TestCase):
    def setUp(self):
        # Seed Python's RNG so the random sizes/parameters drawn by the
        # helper methods below are reproducible across runs.
        super().setUp()
        random.seed(123)
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker    def _is_normal(self, tensor, mean, std):
32*da0073e9SAndroid Build Coastguard Worker        samples = tensor.view(-1).tolist()
33*da0073e9SAndroid Build Coastguard Worker        p_value = stats.kstest(samples, "norm", args=(mean, std))[1]
34*da0073e9SAndroid Build Coastguard Worker        return p_value > 0.0001
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker    def _is_trunc_normal(self, tensor, mean, std, a, b):
37*da0073e9SAndroid Build Coastguard Worker        # scipy's trunc norm is suited for data drawn from N(0, 1),
38*da0073e9SAndroid Build Coastguard Worker        # so we need to transform our data to test it using scipy.
39*da0073e9SAndroid Build Coastguard Worker        z_samples = (tensor.view(-1) - mean) / std
40*da0073e9SAndroid Build Coastguard Worker        z_samples = z_samples.tolist()
41*da0073e9SAndroid Build Coastguard Worker        a0 = (a - mean) / std
42*da0073e9SAndroid Build Coastguard Worker        b0 = (b - mean) / std
43*da0073e9SAndroid Build Coastguard Worker        p_value = stats.kstest(z_samples, "truncnorm", args=(a0, b0))[1]
44*da0073e9SAndroid Build Coastguard Worker        return p_value > 0.0001
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker    def _is_uniform(self, tensor, a, b):
47*da0073e9SAndroid Build Coastguard Worker        samples = tensor.view(-1).tolist()
48*da0073e9SAndroid Build Coastguard Worker        p_value = stats.kstest(samples, "uniform", args=(a, (b - a)))[1]
49*da0073e9SAndroid Build Coastguard Worker        return p_value > 0.0001
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker    def _create_random_nd_tensor(self, dims, size_min, size_max):
52*da0073e9SAndroid Build Coastguard Worker        size = [random.randint(size_min, size_max) for _ in range(dims)]
53*da0073e9SAndroid Build Coastguard Worker        tensor = torch.zeros(size)
54*da0073e9SAndroid Build Coastguard Worker        return tensor
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker    def _random_float(self, a, b):
57*da0073e9SAndroid Build Coastguard Worker        return (b - a) * random.random() + a
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker    def test_calculate_gain_linear(self):
60*da0073e9SAndroid Build Coastguard Worker        for fn in [
61*da0073e9SAndroid Build Coastguard Worker            "linear",
62*da0073e9SAndroid Build Coastguard Worker            "conv1d",
63*da0073e9SAndroid Build Coastguard Worker            "conv2d",
64*da0073e9SAndroid Build Coastguard Worker            "conv3d",
65*da0073e9SAndroid Build Coastguard Worker            "conv_transpose2d",
66*da0073e9SAndroid Build Coastguard Worker            "conv_transpose2d",
67*da0073e9SAndroid Build Coastguard Worker            "conv_transpose3d",
68*da0073e9SAndroid Build Coastguard Worker        ]:
69*da0073e9SAndroid Build Coastguard Worker            gain = init.calculate_gain(fn)
70*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gain, 1)
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker    def test_calculate_gain_nonlinear(self):
73*da0073e9SAndroid Build Coastguard Worker        for fn in ["sigmoid", "tanh", "relu", "leaky_relu"]:
74*da0073e9SAndroid Build Coastguard Worker            gain = init.calculate_gain(fn)
75*da0073e9SAndroid Build Coastguard Worker            if fn == "sigmoid":
76*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(gain, 1)
77*da0073e9SAndroid Build Coastguard Worker            elif fn == "tanh":  # 5 / 3
78*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(gain, 1.6666666666666667)
79*da0073e9SAndroid Build Coastguard Worker            elif fn == "relu":  # sqrt(2)
80*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(gain, 1.4142135623730951)
81*da0073e9SAndroid Build Coastguard Worker            elif fn == "leaky_relu":  # sqrt(2 / 1 + slope^2))
82*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(gain, 1.4141428569978354)
83*da0073e9SAndroid Build Coastguard Worker            elif fn == "selu":
84*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(gain, 0.75)
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker    def test_calculate_gain_leaky_relu(self):
87*da0073e9SAndroid Build Coastguard Worker        for param in [None, 0, 0.01, 10]:
88*da0073e9SAndroid Build Coastguard Worker            gain = init.calculate_gain("leaky_relu", param)
89*da0073e9SAndroid Build Coastguard Worker            if param is None:  # Default slope is 0.01
90*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(gain, 1.4141428569978354)
91*da0073e9SAndroid Build Coastguard Worker            elif param == 0:  # No slope = same gain as normal ReLU
92*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(gain, 1.4142135623730951)
93*da0073e9SAndroid Build Coastguard Worker            elif param == 0.01:
94*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(gain, 1.4141428569978354)
95*da0073e9SAndroid Build Coastguard Worker            elif param == 10:
96*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(gain, 0.14071950894605836)
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker    def test_calculate_gain_leaky_relu_only_accepts_numbers(self):
99*da0073e9SAndroid Build Coastguard Worker        for param in [True, [1], {"a": "b"}]:
100*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
101*da0073e9SAndroid Build Coastguard Worker                init.calculate_gain("leaky_relu", param)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker    def test_calculate_gain_only_accepts_valid_nonlinearities(self):
104*da0073e9SAndroid Build Coastguard Worker        for n in [2, 5, 25]:
105*da0073e9SAndroid Build Coastguard Worker            # Generate random strings of lengths that definitely aren't supported
106*da0073e9SAndroid Build Coastguard Worker            random_string = "".join(
107*da0073e9SAndroid Build Coastguard Worker                [random.choice(string.ascii_lowercase) for i in range(n)]
108*da0073e9SAndroid Build Coastguard Worker            )
109*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
110*da0073e9SAndroid Build Coastguard Worker                init.calculate_gain(random_string)
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
113*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
114*da0073e9SAndroid Build Coastguard Worker    def test_uniform(self):
115*da0073e9SAndroid Build Coastguard Worker        for dims in [1, 2, 4]:
116*da0073e9SAndroid Build Coastguard Worker            input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50)
117*da0073e9SAndroid Build Coastguard Worker            a = self._random_float(-3, 3)
118*da0073e9SAndroid Build Coastguard Worker            b = a + self._random_float(1, 5)
119*da0073e9SAndroid Build Coastguard Worker            init.uniform_(input_tensor, a=a, b=b)
120*da0073e9SAndroid Build Coastguard Worker            assert self._is_uniform(input_tensor, a, b)
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
123*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
124*da0073e9SAndroid Build Coastguard Worker    def test_normal(self):
125*da0073e9SAndroid Build Coastguard Worker        for dims in [1, 2, 4]:
126*da0073e9SAndroid Build Coastguard Worker            input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50)
127*da0073e9SAndroid Build Coastguard Worker            mean = self._random_float(-3, 3)
128*da0073e9SAndroid Build Coastguard Worker            std = self._random_float(1, 5)
129*da0073e9SAndroid Build Coastguard Worker            init.normal_(input_tensor, mean=mean, std=std)
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker            assert self._is_normal(input_tensor, mean, std)
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
134*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
135*da0073e9SAndroid Build Coastguard Worker    def test_trunc_normal(self):
136*da0073e9SAndroid Build Coastguard Worker        for dims in [1, 2, 4]:
137*da0073e9SAndroid Build Coastguard Worker            input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50)
138*da0073e9SAndroid Build Coastguard Worker            mean = self._random_float(-3, 3)
139*da0073e9SAndroid Build Coastguard Worker            std = self._random_float(0.01, 1)
140*da0073e9SAndroid Build Coastguard Worker            a = self._random_float(mean - 2 * std, mean)
141*da0073e9SAndroid Build Coastguard Worker            b = self._random_float(mean, mean + 2 * std)
142*da0073e9SAndroid Build Coastguard Worker            init.trunc_normal_(input_tensor, mean=mean, std=std, a=a, b=b)
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker            assert self._is_trunc_normal(input_tensor, mean, std, a, b)
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
147*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
148*da0073e9SAndroid Build Coastguard Worker    def test_trunc_normal_generator(self):
149*da0073e9SAndroid Build Coastguard Worker        gen = torch.Generator()
150*da0073e9SAndroid Build Coastguard Worker        gen.manual_seed(42)
151*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.empty(5)
152*da0073e9SAndroid Build Coastguard Worker        init.trunc_normal_(input_tensor, generator=gen)
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker        ref = torch.empty(5)
155*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(42)
156*da0073e9SAndroid Build Coastguard Worker        init.trunc_normal_(ref)
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input_tensor, ref)
159*da0073e9SAndroid Build Coastguard Worker        assert self._is_trunc_normal(input_tensor, mean=0, std=1, a=0, b=1)
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker    def test_constant(self):
162*da0073e9SAndroid Build Coastguard Worker        for dims in [1, 2, 4]:
163*da0073e9SAndroid Build Coastguard Worker            input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5)
164*da0073e9SAndroid Build Coastguard Worker            val = self._random_float(1, 10)
165*da0073e9SAndroid Build Coastguard Worker            init.constant_(input_tensor, val)
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input_tensor, input_tensor.clone().fill_(val))
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker    def test_ones_and_zeros(self):
170*da0073e9SAndroid Build Coastguard Worker        for init_fn_, val in zip([init.ones_, init.zeros_], [1, 0]):
171*da0073e9SAndroid Build Coastguard Worker            for dims in [1, 2, 4]:
172*da0073e9SAndroid Build Coastguard Worker                input_tensor = self._create_random_nd_tensor(
173*da0073e9SAndroid Build Coastguard Worker                    dims, size_min=1, size_max=5
174*da0073e9SAndroid Build Coastguard Worker                )
175*da0073e9SAndroid Build Coastguard Worker                init_fn_(input_tensor)
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(input_tensor, input_tensor.clone().fill_(val))
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker    def test_eye(self):
180*da0073e9SAndroid Build Coastguard Worker        input_tensor = self._create_random_nd_tensor(2, size_min=1, size_max=5)
181*da0073e9SAndroid Build Coastguard Worker        init.eye_(input_tensor)
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker        # Check every single element
184*da0073e9SAndroid Build Coastguard Worker        for i in range(input_tensor.size(0)):
185*da0073e9SAndroid Build Coastguard Worker            for j in range(input_tensor.size(1)):
186*da0073e9SAndroid Build Coastguard Worker                if i == j:
187*da0073e9SAndroid Build Coastguard Worker                    assert input_tensor[i][j] == 1
188*da0073e9SAndroid Build Coastguard Worker                else:
189*da0073e9SAndroid Build Coastguard Worker                    assert input_tensor[i][j] == 0
190*da0073e9SAndroid Build Coastguard Worker
191*da0073e9SAndroid Build Coastguard Worker    def test_eye_only_works_on_2d_inputs(self):
192*da0073e9SAndroid Build Coastguard Worker        for dims in [1, 3]:
193*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
194*da0073e9SAndroid Build Coastguard Worker                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3)
195*da0073e9SAndroid Build Coastguard Worker                init.eye_(tensor)
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker    def test_dirac_properties(self):
198*da0073e9SAndroid Build Coastguard Worker        for dims in [3, 4, 5]:
199*da0073e9SAndroid Build Coastguard Worker            for groups in [1, 2, 3]:
200*da0073e9SAndroid Build Coastguard Worker                # prepare random tensor with random sizes, but fits groups
201*da0073e9SAndroid Build Coastguard Worker                a, c, d, e = (random.randint(1, 5) for _ in range(4))
202*da0073e9SAndroid Build Coastguard Worker                b = random.randint(
203*da0073e9SAndroid Build Coastguard Worker                    1, 5 * groups
204*da0073e9SAndroid Build Coastguard Worker                )  # same range as a*groups but all range allowed
205*da0073e9SAndroid Build Coastguard Worker                # make sure first dim divides by groups
206*da0073e9SAndroid Build Coastguard Worker                input_tensor = torch.randn((a * groups, b, c, d, e)[:dims])
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker                init.dirac_(input_tensor, groups)
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker                c_out, c_in = input_tensor.size(0) // groups, input_tensor.size(1)
211*da0073e9SAndroid Build Coastguard Worker                min_d = min(c_out, c_in)
212*da0073e9SAndroid Build Coastguard Worker                # Check number of nonzeros is equivalent to smallest dim (for each group)
213*da0073e9SAndroid Build Coastguard Worker                assert torch.nonzero(input_tensor).size(0) == min_d * groups
214*da0073e9SAndroid Build Coastguard Worker                # Check sum of values (can have precision issues, hence assertEqual) is also equivalent
215*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(input_tensor.sum(), min_d * groups)
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker    def test_dirac_identity(self):
218*da0073e9SAndroid Build Coastguard Worker        for groups in [1, 3]:
219*da0073e9SAndroid Build Coastguard Worker            batch, in_c, out_c, size, kernel_size = (
220*da0073e9SAndroid Build Coastguard Worker                8,
221*da0073e9SAndroid Build Coastguard Worker                3,
222*da0073e9SAndroid Build Coastguard Worker                9,
223*da0073e9SAndroid Build Coastguard Worker                5,
224*da0073e9SAndroid Build Coastguard Worker                3,
225*da0073e9SAndroid Build Coastguard Worker            )  # in_c, out_c must divide by groups
226*da0073e9SAndroid Build Coastguard Worker            eff_out_c = out_c // groups
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker            # Test 1D
229*da0073e9SAndroid Build Coastguard Worker            input_var = torch.randn(batch, in_c, size)
230*da0073e9SAndroid Build Coastguard Worker            filter_var = torch.zeros(eff_out_c, in_c, kernel_size)
231*da0073e9SAndroid Build Coastguard Worker            filter_var = torch.cat([filter_var] * groups)
232*da0073e9SAndroid Build Coastguard Worker            init.dirac_(filter_var, groups)
233*da0073e9SAndroid Build Coastguard Worker            output_var = F.conv1d(input_var, filter_var)
234*da0073e9SAndroid Build Coastguard Worker            input_tensor, output_tensor = (
235*da0073e9SAndroid Build Coastguard Worker                input_var.data,
236*da0073e9SAndroid Build Coastguard Worker                output_var.data,
237*da0073e9SAndroid Build Coastguard Worker            )  # Variables do not support nonzero
238*da0073e9SAndroid Build Coastguard Worker            for g in range(groups):
239*da0073e9SAndroid Build Coastguard Worker                # Assert in_c outputs are preserved (per each group)
240*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
241*da0073e9SAndroid Build Coastguard Worker                    input_tensor[:, :, 1:-1],
242*da0073e9SAndroid Build Coastguard Worker                    output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :],
243*da0073e9SAndroid Build Coastguard Worker                )
244*da0073e9SAndroid Build Coastguard Worker                # Assert extra outputs are 0
245*da0073e9SAndroid Build Coastguard Worker                assert (
246*da0073e9SAndroid Build Coastguard Worker                    torch.nonzero(
247*da0073e9SAndroid Build Coastguard Worker                        output_tensor[:, eff_out_c * g + in_c : eff_out_c * (g + 1), :]
248*da0073e9SAndroid Build Coastguard Worker                    ).numel()
249*da0073e9SAndroid Build Coastguard Worker                    == 0
250*da0073e9SAndroid Build Coastguard Worker                )
251*da0073e9SAndroid Build Coastguard Worker
252*da0073e9SAndroid Build Coastguard Worker            # Test 2D
253*da0073e9SAndroid Build Coastguard Worker            input_var = torch.randn(batch, in_c, size, size)
254*da0073e9SAndroid Build Coastguard Worker            filter_var = torch.zeros(eff_out_c, in_c, kernel_size, kernel_size)
255*da0073e9SAndroid Build Coastguard Worker            filter_var = torch.cat([filter_var] * groups)
256*da0073e9SAndroid Build Coastguard Worker            init.dirac_(filter_var, groups)
257*da0073e9SAndroid Build Coastguard Worker            output_var = F.conv2d(input_var, filter_var)
258*da0073e9SAndroid Build Coastguard Worker            input_tensor, output_tensor = (
259*da0073e9SAndroid Build Coastguard Worker                input_var.data,
260*da0073e9SAndroid Build Coastguard Worker                output_var.data,
261*da0073e9SAndroid Build Coastguard Worker            )  # Variables do not support nonzero
262*da0073e9SAndroid Build Coastguard Worker            for g in range(groups):
263*da0073e9SAndroid Build Coastguard Worker                # Assert in_c outputs are preserved (per each group)
264*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
265*da0073e9SAndroid Build Coastguard Worker                    input_tensor[:, :, 1:-1, 1:-1],
266*da0073e9SAndroid Build Coastguard Worker                    output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :, :],
267*da0073e9SAndroid Build Coastguard Worker                )
268*da0073e9SAndroid Build Coastguard Worker                # Assert extra outputs are 0
269*da0073e9SAndroid Build Coastguard Worker                assert (
270*da0073e9SAndroid Build Coastguard Worker                    torch.nonzero(
271*da0073e9SAndroid Build Coastguard Worker                        output_tensor[
272*da0073e9SAndroid Build Coastguard Worker                            :, eff_out_c * g + in_c : eff_out_c * (g + 1), :, :
273*da0073e9SAndroid Build Coastguard Worker                        ]
274*da0073e9SAndroid Build Coastguard Worker                    ).numel()
275*da0073e9SAndroid Build Coastguard Worker                    == 0
276*da0073e9SAndroid Build Coastguard Worker                )
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker            # Test 3D
279*da0073e9SAndroid Build Coastguard Worker            input_var = torch.randn(batch, in_c, size, size, size)
280*da0073e9SAndroid Build Coastguard Worker            filter_var = torch.zeros(
281*da0073e9SAndroid Build Coastguard Worker                eff_out_c, in_c, kernel_size, kernel_size, kernel_size
282*da0073e9SAndroid Build Coastguard Worker            )
283*da0073e9SAndroid Build Coastguard Worker            filter_var = torch.cat([filter_var] * groups)
284*da0073e9SAndroid Build Coastguard Worker            init.dirac_(filter_var, groups)
285*da0073e9SAndroid Build Coastguard Worker            output_var = F.conv3d(input_var, filter_var)
286*da0073e9SAndroid Build Coastguard Worker            input_tensor, output_tensor = input_var.data, output_var.data
287*da0073e9SAndroid Build Coastguard Worker            for g in range(groups):
288*da0073e9SAndroid Build Coastguard Worker                # Assert in_c outputs are preserved (per each group)
289*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
290*da0073e9SAndroid Build Coastguard Worker                    input_tensor[:, :, 1:-1, 1:-1, 1:-1],
291*da0073e9SAndroid Build Coastguard Worker                    output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :, :, :],
292*da0073e9SAndroid Build Coastguard Worker                )
293*da0073e9SAndroid Build Coastguard Worker                # Assert extra outputs are 0
294*da0073e9SAndroid Build Coastguard Worker                assert (
295*da0073e9SAndroid Build Coastguard Worker                    torch.nonzero(
296*da0073e9SAndroid Build Coastguard Worker                        output_tensor[
297*da0073e9SAndroid Build Coastguard Worker                            :, eff_out_c * g + in_c : eff_out_c * (g + 1), :, :, :
298*da0073e9SAndroid Build Coastguard Worker                        ]
299*da0073e9SAndroid Build Coastguard Worker                    ).numel()
300*da0073e9SAndroid Build Coastguard Worker                    == 0
301*da0073e9SAndroid Build Coastguard Worker                )
302*da0073e9SAndroid Build Coastguard Worker
303*da0073e9SAndroid Build Coastguard Worker    def test_dirac_only_works_on_3_4_5d_inputs(self):
304*da0073e9SAndroid Build Coastguard Worker        for dims in [1, 2, 6]:
305*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
306*da0073e9SAndroid Build Coastguard Worker                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3)
307*da0073e9SAndroid Build Coastguard Worker                init.dirac_(tensor)
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker    def test_xavier_uniform_errors_on_inputs_smaller_than_2d(self):
310*da0073e9SAndroid Build Coastguard Worker        for dims in [0, 1]:
311*da0073e9SAndroid Build Coastguard Worker            tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
312*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
313*da0073e9SAndroid Build Coastguard Worker                init.xavier_uniform_(tensor)
314*da0073e9SAndroid Build Coastguard Worker
315*da0073e9SAndroid Build Coastguard Worker    def test_xavier_normal_errors_on_inputs_smaller_than_2d(self):
316*da0073e9SAndroid Build Coastguard Worker        for dims in [0, 1]:
317*da0073e9SAndroid Build Coastguard Worker            tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
318*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
319*da0073e9SAndroid Build Coastguard Worker                init.xavier_normal_(tensor)
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
322*da0073e9SAndroid Build Coastguard Worker    @slowTest
323*da0073e9SAndroid Build Coastguard Worker    def test_xavier_uniform(self):
324*da0073e9SAndroid Build Coastguard Worker        for use_gain in [True, False]:
325*da0073e9SAndroid Build Coastguard Worker            for dims in [2, 4]:
326*da0073e9SAndroid Build Coastguard Worker                input_tensor = self._create_random_nd_tensor(
327*da0073e9SAndroid Build Coastguard Worker                    dims, size_min=20, size_max=25
328*da0073e9SAndroid Build Coastguard Worker                )
329*da0073e9SAndroid Build Coastguard Worker                gain = 1
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker                if use_gain:
332*da0073e9SAndroid Build Coastguard Worker                    gain = self._random_float(0.1, 2)
333*da0073e9SAndroid Build Coastguard Worker                    init.xavier_uniform_(input_tensor, gain=gain)
334*da0073e9SAndroid Build Coastguard Worker                else:
335*da0073e9SAndroid Build Coastguard Worker                    init.xavier_uniform_(input_tensor)
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker                fan_in = input_tensor.size(1)
338*da0073e9SAndroid Build Coastguard Worker                fan_out = input_tensor.size(0)
339*da0073e9SAndroid Build Coastguard Worker                if input_tensor.dim() > 2:
340*da0073e9SAndroid Build Coastguard Worker                    fan_in *= input_tensor[0, 0].numel()
341*da0073e9SAndroid Build Coastguard Worker                    fan_out *= input_tensor[0, 0].numel()
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker                expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out))
344*da0073e9SAndroid Build Coastguard Worker                bounds = expected_std * math.sqrt(3)
345*da0073e9SAndroid Build Coastguard Worker                assert self._is_uniform(input_tensor, -bounds, bounds)
346*da0073e9SAndroid Build Coastguard Worker
347*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
348*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
349*da0073e9SAndroid Build Coastguard Worker    def test_xavier_normal(self):
350*da0073e9SAndroid Build Coastguard Worker        for use_gain in [True, False]:
351*da0073e9SAndroid Build Coastguard Worker            for dims in [2, 4]:
352*da0073e9SAndroid Build Coastguard Worker                input_tensor = self._create_random_nd_tensor(
353*da0073e9SAndroid Build Coastguard Worker                    dims, size_min=20, size_max=25
354*da0073e9SAndroid Build Coastguard Worker                )
355*da0073e9SAndroid Build Coastguard Worker                gain = 1
356*da0073e9SAndroid Build Coastguard Worker
357*da0073e9SAndroid Build Coastguard Worker                if use_gain:
358*da0073e9SAndroid Build Coastguard Worker                    gain = self._random_float(0.1, 2)
359*da0073e9SAndroid Build Coastguard Worker                    init.xavier_normal_(input_tensor, gain=gain)
360*da0073e9SAndroid Build Coastguard Worker                else:
361*da0073e9SAndroid Build Coastguard Worker                    init.xavier_normal_(input_tensor)
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker                fan_in = input_tensor.size(1)
364*da0073e9SAndroid Build Coastguard Worker                fan_out = input_tensor.size(0)
365*da0073e9SAndroid Build Coastguard Worker                if input_tensor.dim() > 2:
366*da0073e9SAndroid Build Coastguard Worker                    fan_in *= input_tensor[0, 0].numel()
367*da0073e9SAndroid Build Coastguard Worker                    fan_out *= input_tensor[0, 0].numel()
368*da0073e9SAndroid Build Coastguard Worker
369*da0073e9SAndroid Build Coastguard Worker                expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out))
370*da0073e9SAndroid Build Coastguard Worker                assert self._is_normal(input_tensor, 0, expected_std)
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Worker    def test_kaiming_uniform_errors_on_inputs_smaller_than_2d(self):
373*da0073e9SAndroid Build Coastguard Worker        for dims in [0, 1]:
374*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
375*da0073e9SAndroid Build Coastguard Worker                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
376*da0073e9SAndroid Build Coastguard Worker                init.kaiming_uniform_(tensor)
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker    def test_kaiming_normal_errors_on_inputs_smaller_than_2d(self):
379*da0073e9SAndroid Build Coastguard Worker        for dims in [0, 1]:
380*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
381*da0073e9SAndroid Build Coastguard Worker                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1)
382*da0073e9SAndroid Build Coastguard Worker                init.kaiming_normal_(tensor)
383*da0073e9SAndroid Build Coastguard Worker
384*da0073e9SAndroid Build Coastguard Worker    def test_kaiming_uniform_warning_on_0element_tensor(self):
385*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty(0, 1)
386*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
387*da0073e9SAndroid Build Coastguard Worker            UserWarning, "Initializing zero-element tensors is a no-op"
388*da0073e9SAndroid Build Coastguard Worker        ):
389*da0073e9SAndroid Build Coastguard Worker            _ = init.kaiming_uniform_(tensor)
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker    def test_kaiming_normal_warning_on_0element_tensor(self):
392*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty(0, 1)
393*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
394*da0073e9SAndroid Build Coastguard Worker            UserWarning, "Initializing zero-element tensors is a no-op"
395*da0073e9SAndroid Build Coastguard Worker        ):
396*da0073e9SAndroid Build Coastguard Worker            _ = init.kaiming_normal_(tensor)
397*da0073e9SAndroid Build Coastguard Worker
398*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
399*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
400*da0073e9SAndroid Build Coastguard Worker    def test_kaiming_uniform(self):
401*da0073e9SAndroid Build Coastguard Worker        for use_a in [True, False]:
402*da0073e9SAndroid Build Coastguard Worker            for dims in [2, 4]:
403*da0073e9SAndroid Build Coastguard Worker                for mode in ["fan_in", "fan_out"]:
404*da0073e9SAndroid Build Coastguard Worker                    input_tensor = self._create_random_nd_tensor(
405*da0073e9SAndroid Build Coastguard Worker                        dims, size_min=20, size_max=25
406*da0073e9SAndroid Build Coastguard Worker                    )
407*da0073e9SAndroid Build Coastguard Worker                    if use_a:
408*da0073e9SAndroid Build Coastguard Worker                        a = self._random_float(0.1, 2)
409*da0073e9SAndroid Build Coastguard Worker                        init.kaiming_uniform_(input_tensor, a=a, mode=mode)
410*da0073e9SAndroid Build Coastguard Worker                    else:
411*da0073e9SAndroid Build Coastguard Worker                        a = 0
412*da0073e9SAndroid Build Coastguard Worker                        init.kaiming_uniform_(input_tensor, mode=mode)
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker                    fan_in = input_tensor.size(1)
415*da0073e9SAndroid Build Coastguard Worker                    fan_out = input_tensor.size(0)
416*da0073e9SAndroid Build Coastguard Worker                    if input_tensor.dim() > 2:
417*da0073e9SAndroid Build Coastguard Worker                        fan_in *= input_tensor[0, 0].numel()
418*da0073e9SAndroid Build Coastguard Worker                        fan_out *= input_tensor[0, 0].numel()
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker                    if mode == "fan_in":
421*da0073e9SAndroid Build Coastguard Worker                        n = fan_in
422*da0073e9SAndroid Build Coastguard Worker                    else:
423*da0073e9SAndroid Build Coastguard Worker                        n = fan_out
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker                    expected_std = math.sqrt(2.0 / ((1 + a**2) * n))
426*da0073e9SAndroid Build Coastguard Worker                    bounds = expected_std * math.sqrt(3.0)
427*da0073e9SAndroid Build Coastguard Worker                    assert self._is_uniform(input_tensor, -bounds, bounds)
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
430*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
431*da0073e9SAndroid Build Coastguard Worker    def test_kaiming_normal(self):
432*da0073e9SAndroid Build Coastguard Worker        for use_a in [True, False]:
433*da0073e9SAndroid Build Coastguard Worker            for dims in [2, 4]:
434*da0073e9SAndroid Build Coastguard Worker                for mode in ["fan_in", "fan_out"]:
435*da0073e9SAndroid Build Coastguard Worker                    input_tensor = self._create_random_nd_tensor(
436*da0073e9SAndroid Build Coastguard Worker                        dims, size_min=20, size_max=25
437*da0073e9SAndroid Build Coastguard Worker                    )
438*da0073e9SAndroid Build Coastguard Worker                    if use_a:
439*da0073e9SAndroid Build Coastguard Worker                        a = self._random_float(0.1, 2)
440*da0073e9SAndroid Build Coastguard Worker                        init.kaiming_normal_(input_tensor, a=a, mode=mode)
441*da0073e9SAndroid Build Coastguard Worker                    else:
442*da0073e9SAndroid Build Coastguard Worker                        a = 0
443*da0073e9SAndroid Build Coastguard Worker                        init.kaiming_normal_(input_tensor, mode=mode)
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker                    fan_in = input_tensor.size(1)
446*da0073e9SAndroid Build Coastguard Worker                    fan_out = input_tensor.size(0)
447*da0073e9SAndroid Build Coastguard Worker                    if input_tensor.dim() > 2:
448*da0073e9SAndroid Build Coastguard Worker                        fan_in *= input_tensor[0, 0].numel()
449*da0073e9SAndroid Build Coastguard Worker                        fan_out *= input_tensor[0, 0].numel()
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker                    if mode == "fan_in":
452*da0073e9SAndroid Build Coastguard Worker                        n = fan_in
453*da0073e9SAndroid Build Coastguard Worker                    else:
454*da0073e9SAndroid Build Coastguard Worker                        n = fan_out
455*da0073e9SAndroid Build Coastguard Worker
456*da0073e9SAndroid Build Coastguard Worker                    expected_std = math.sqrt(2.0 / ((1 + a**2) * n))
457*da0073e9SAndroid Build Coastguard Worker                    assert self._is_normal(input_tensor, 0, expected_std)
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker    def test_sparse_only_works_on_2d_inputs(self):
460*da0073e9SAndroid Build Coastguard Worker        for dims in [1, 3]:
461*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
462*da0073e9SAndroid Build Coastguard Worker                sparsity = self._random_float(0.1, 0.9)
463*da0073e9SAndroid Build Coastguard Worker                tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3)
464*da0073e9SAndroid Build Coastguard Worker                init.sparse_(tensor, sparsity)
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
467*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("scipy.kstest is failing under dynamo")
468*da0073e9SAndroid Build Coastguard Worker    def test_sparse_default_std(self):
469*da0073e9SAndroid Build Coastguard Worker        for use_random_std in [True, False]:
470*da0073e9SAndroid Build Coastguard Worker            input_tensor = self._create_random_nd_tensor(2, size_min=30, size_max=35)
471*da0073e9SAndroid Build Coastguard Worker            rows, cols = input_tensor.size(0), input_tensor.size(1)
472*da0073e9SAndroid Build Coastguard Worker            sparsity = self._random_float(0.1, 0.2)
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker            std = 0.01  # default std
475*da0073e9SAndroid Build Coastguard Worker            if use_random_std:
476*da0073e9SAndroid Build Coastguard Worker                std = self._random_float(0.01, 0.2)
477*da0073e9SAndroid Build Coastguard Worker                init.sparse_(input_tensor, sparsity=sparsity, std=std)
478*da0073e9SAndroid Build Coastguard Worker            else:
479*da0073e9SAndroid Build Coastguard Worker                init.sparse_(input_tensor, sparsity=sparsity)
480*da0073e9SAndroid Build Coastguard Worker
481*da0073e9SAndroid Build Coastguard Worker            for col_idx in range(input_tensor.size(1)):
482*da0073e9SAndroid Build Coastguard Worker                column = input_tensor[:, col_idx]
483*da0073e9SAndroid Build Coastguard Worker                assert column[column == 0].nelement() >= math.ceil(sparsity * rows)
484*da0073e9SAndroid Build Coastguard Worker
485*da0073e9SAndroid Build Coastguard Worker            assert self._is_normal(input_tensor[input_tensor != 0], 0, std)
486*da0073e9SAndroid Build Coastguard Worker
487*da0073e9SAndroid Build Coastguard Worker    @skipIfNoLapack
488*da0073e9SAndroid Build Coastguard Worker    def test_orthogonal(self):
489*da0073e9SAndroid Build Coastguard Worker        for use_gain in [True, False]:
490*da0073e9SAndroid Build Coastguard Worker            for tensor_size in [[3, 4], [4, 3], [20, 2, 3, 4], [2, 3, 4, 5]]:
491*da0073e9SAndroid Build Coastguard Worker                input_tensor = torch.zeros(tensor_size)
492*da0073e9SAndroid Build Coastguard Worker                gain = 1.0
493*da0073e9SAndroid Build Coastguard Worker
494*da0073e9SAndroid Build Coastguard Worker                if use_gain:
495*da0073e9SAndroid Build Coastguard Worker                    gain = self._random_float(0.1, 2)
496*da0073e9SAndroid Build Coastguard Worker                    init.orthogonal_(input_tensor, gain=gain)
497*da0073e9SAndroid Build Coastguard Worker                else:
498*da0073e9SAndroid Build Coastguard Worker                    init.orthogonal_(input_tensor)
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Worker                rows, cols = tensor_size[0], reduce(mul, tensor_size[1:])
501*da0073e9SAndroid Build Coastguard Worker                flattened_tensor = input_tensor.view(rows, cols)
502*da0073e9SAndroid Build Coastguard Worker                if rows > cols:
503*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
504*da0073e9SAndroid Build Coastguard Worker                        torch.mm(flattened_tensor.t(), flattened_tensor),
505*da0073e9SAndroid Build Coastguard Worker                        torch.eye(cols) * gain**2,
506*da0073e9SAndroid Build Coastguard Worker                        atol=1e-6,
507*da0073e9SAndroid Build Coastguard Worker                        rtol=0,
508*da0073e9SAndroid Build Coastguard Worker                    )
509*da0073e9SAndroid Build Coastguard Worker                else:
510*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
511*da0073e9SAndroid Build Coastguard Worker                        torch.mm(flattened_tensor, flattened_tensor.t()),
512*da0073e9SAndroid Build Coastguard Worker                        torch.eye(rows) * gain**2,
513*da0073e9SAndroid Build Coastguard Worker                        atol=1e-6,
514*da0073e9SAndroid Build Coastguard Worker                        rtol=0,
515*da0073e9SAndroid Build Coastguard Worker                    )
516*da0073e9SAndroid Build Coastguard Worker
517*da0073e9SAndroid Build Coastguard Worker    def test_deprecation(self):
518*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3)
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker        def fn():
521*da0073e9SAndroid Build Coastguard Worker            init.normal(x)
522*da0073e9SAndroid Build Coastguard Worker
523*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
524*da0073e9SAndroid Build Coastguard Worker            FutureWarning,
525*da0073e9SAndroid Build Coastguard Worker            "deprecated",
526*da0073e9SAndroid Build Coastguard Worker            msg="methods not suffixed with underscore should be deprecated",
527*da0073e9SAndroid Build Coastguard Worker        ):
528*da0073e9SAndroid Build Coastguard Worker            fn()
529*da0073e9SAndroid Build Coastguard Worker
530*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Allow running this test file directly; run_tests() is PyTorch's
    # common_utils entry point that dispatches to the configured test runner.
    run_tests()
533