# Owner(s): ["module: nn"]
import pickle
import unittest

import torch
import torch.nn as nn
from torch.nn import Buffer, Parameter
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    run_tests,
    suppress_warnings,
    TEST_PRIVATEUSE1,
    TestCase,
)

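# Minimal lazy module used throughout these tests: mixing LazyModuleMixin into
# nn.Module lets the tests attach uninitialized parameters and buffers directly.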
class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
    pass


class TestLazyModules(TestCase):
    @suppress_warnings
    def test_lazy_module_parameter(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict["test_param"], UninitializedParameter)
        new_module = LazyModule()
        # An error is raised when attempting to replace an existing parameter
        # with an uninitialized one
        new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5)))
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            new_module.load_state_dict(state_dict)
        # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5)))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_param, torch.ones((5, 5)))

        # Uninitialized parameters are left unchanged
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.register_parameter("test_param", UninitializedParameter())
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict["test_buffer"], UninitializedBuffer)
        new_module = LazyModule()
        # An error is raised when attempting to replace an existing buffer
        # with an uninitialized one
        new_module.test_buffer = Buffer(torch.ones(5, 5))
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            new_module.load_state_dict(state_dict)
        # Uninitialized buffers are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.test_buffer = Buffer(torch.ones(5, 5))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_buffer, torch.ones((5, 5)))

        # Uninitialized buffers are left unchanged
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.test_buffer = UninitializedBuffer()
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_jit_param(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "run a forward pass"):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_module_jit_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "run a forward pass"):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_share_memory_param(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"):
            module.share_memory()

    @suppress_warnings
    def test_lazy_share_memory_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"):
            module.share_memory()

    @suppress_warnings
    def test_linear(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)
        self.assertIsInstance(module, nn.Linear)
        self.assertNotIsInstance(module, nn.LazyLinear)
        self.assertTrue(module.weight.shape == (10, 5))
        self.assertTrue(module.bias.shape == (10,))
        y = module(input)
        self.assertTrue(
            torch.equal(
                torch.nn.functional.linear(input, module.weight, module.bias), y
            )
        )

    @suppress_warnings
    def test_lazy_linear_pickle(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, nn.LazyLinear)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, nn.Linear)
        self.assertNotIsInstance(new_module, nn.LazyLinear)
        self.assertTrue(new_module.weight.shape == (10, 5))
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        self.assertTrue(new_module.bias.shape == (10,))
        self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    @suppress_warnings
    def test_linear_state(self):
        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized, but the module won't become a full
        # Linear one until the first forward pass. This is due to
        # limitations of the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertTrue(lazy_module.weight.shape == (10, 5))
        self.assertTrue(lazy_module.bias.shape == (10,))

        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_conv(
        self,
        cls,
        lazy_cls,
        func,
        init_args,
        input_shape,
        expected_weight_shape,
        expected_bias_shape,
        *forward_args,
        **forward_kwargs,
    ):
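        """Materialize ``lazy_cls(*init_args)`` with a forward pass, then verify
        the class swap, the parameter shapes, and that a second forward pass
        matches the functional ``func``."""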
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input, *forward_args, **forward_kwargs)
        self.assertIsInstance(module, cls)
        self.assertNotIsInstance(module, lazy_cls)
        self.assertEqual(module.weight.shape, expected_weight_shape)
        if module.bias is not None:
            self.assertEqual(module.bias.shape, expected_bias_shape)
        y = module(input)
        self.assertTrue(torch.equal(func(input, module.weight, module.bias), y))

    def _check_lazy_conv_pickle(
        self,
        cls,
        lazy_cls,
        init_args,
        input_shape,
        expected_weight_shape,
        expected_bias_shape,
    ):
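        """Pickle round-trip the module before materialization (it must stay
        lazy and uninitialized) and after a forward pass (it must become a
        plain ``cls`` with concrete parameters)."""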
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, lazy_cls)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, cls)
        self.assertNotIsInstance(new_module, lazy_cls)
        self.assertEqual(new_module.weight.shape, expected_weight_shape)
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        if new_module.bias is not None:
            self.assertEqual(new_module.bias.shape, expected_bias_shape)
            self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    def _check_lazy_conv_state(
        self, gen_module, gen_lazy_module, expected_weight_shape, expected_bias_shape
    ):
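        """Loading a regular module's state dict must initialize the lazy
        module's parameters; loading a lazy module's state dict into a regular
        module must raise."""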
        module = gen_module()
        lazy_module = gen_lazy_module()
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized, but the module won't become a full
        # Conv one until the first forward pass. This is due to
        # limitations of the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, expected_weight_shape)
        if lazy_module.bias is not None:
            self.assertEqual(lazy_module.bias.shape, expected_bias_shape)

        module = gen_module()
        lazy_module = gen_lazy_module()
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def test_lazy_pre_forward_hook(self):
        """
        Verify that a lazy module can register additional forward pre-hooks
        successfully.
        """

        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input):
            return input[0] + 1

        module = TestModule()
        module.register_forward_pre_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    def test_lazy_forward_hook(self):
        """
        Verify that a lazy module can register additional forward hooks
        successfully.
        """

        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input, output):
            return input[0] + 1

        module = TestModule()
        module.register_forward_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    @suppress_warnings
    def test_lazy_conv1d(self):
        self._check_lazy_conv(
            nn.Conv1d,
            nn.LazyConv1d,
            torch.nn.functional.conv1d,
            (32, 2),
            (192, 16, 50),
            (32, 16, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv1d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv1d, nn.LazyConv1d, (32, 2), (192, 16, 50), (32, 16, 2), (32,)
        )

    @suppress_warnings
    def test_lazy_conv1d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv1d(16, 32, 2),
            lambda: nn.LazyConv1d(32, 2),
            (32, 16, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv2d(self):
        self._check_lazy_conv(
            nn.Conv2d,
            nn.LazyConv2d,
            torch.nn.functional.conv2d,
            (32, 2),
            (192, 16, 8, 6),
            (32, 16, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv2d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv2d, nn.LazyConv2d, (32, 2), (192, 16, 8, 6), (32, 16, 2, 2), (32,)
        )

    @suppress_warnings
    def test_lazy_conv2d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv2d(16, 32, 2),
            lambda: nn.LazyConv2d(32, 2),
            (32, 16, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d(self):
        self._check_lazy_conv(
            nn.Conv3d,
            nn.LazyConv3d,
            torch.nn.functional.conv3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv3d,
            nn.LazyConv3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv3d(16, 32, 2),
            lambda: nn.LazyConv3d(32, 2),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transposed1d(self):
        self._check_lazy_conv(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            torch.nn.functional.conv_transpose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            torch.nn.functional.conv_transpose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
            output_size=(51,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose1d(16, 32, 2),
            lambda: nn.LazyConvTranspose1d(32, 2),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d(self):
        self._check_lazy_conv(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            torch.nn.functional.conv_transpose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            torch.nn.functional.conv_transpose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
            output_size=(9, 7),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose2d(16, 32, 2),
            lambda: nn.LazyConvTranspose2d(32, 2),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d(self):
        self._check_lazy_conv(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            torch.nn.functional.conv_transpose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            torch.nn.functional.conv_transpose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
            output_size=(9, 8, 7),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose3d(16, 32, 2),
            lambda: nn.LazyConvTranspose3d(32, 2),
            (16, 32, 2, 2, 2),
            (32,),
        )

    def _check_lazy_norm(self, cls, lazy_cls, input_shape):
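        """For every combination of affine and track_running_stats, materialize
        the lazy norm module with a forward pass and compare its output,
        parameters, and running statistics against an eagerly constructed
        ``cls``."""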
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                lazy_module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )

                if affine:
                    self.assertIsInstance(lazy_module.weight, UninitializedParameter)
                    self.assertIsInstance(lazy_module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(lazy_module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                lazy_output = lazy_module(input)
                self.assertIsInstance(lazy_module, cls)
                self.assertNotIsInstance(lazy_module, lazy_cls)

                num_features = input_shape[1]
                module = cls(
                    num_features, affine=affine, track_running_stats=track_running_stats
                )
                expected_output = module(input)

                self.assertEqual(lazy_output, expected_output)
                if module.weight is not None:
                    self.assertEqual(lazy_module.weight.shape, module.weight.shape)
                    self.assertEqual(lazy_module.weight, module.weight)
                if module.bias is not None:
                    self.assertEqual(lazy_module.bias.shape, module.bias.shape)
                    self.assertEqual(lazy_module.bias, module.bias)
                if module.running_mean is not None:
                    self.assertEqual(
                        lazy_module.running_mean.shape, module.running_mean.shape
                    )
                    self.assertEqual(lazy_module.running_mean, module.running_mean)
                if module.running_var is not None:
                    self.assertEqual(
                        lazy_module.running_var.shape, module.running_var.shape
                    )
                    self.assertEqual(lazy_module.running_var, module.running_var)
                if module.num_batches_tracked is not None:
                    self.assertEqual(
                        lazy_module.num_batches_tracked.shape,
                        module.num_batches_tracked.shape,
                    )
                    self.assertEqual(
                        lazy_module.num_batches_tracked, module.num_batches_tracked
                    )

    def _check_lazy_norm_pickle(self, cls, lazy_cls, input_shape):
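        """Pickle round-trip the lazy norm module before materialization (it
        must stay lazy with uninitialized parameters/buffers) and after a
        forward pass (it must become a plain ``cls`` with initialized ones)."""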
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )
                module = pickle.loads(pickle.dumps(module))

                self.assertIsInstance(module, lazy_cls)
                if affine:
                    self.assertIsInstance(module.weight, UninitializedParameter)
                    self.assertIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                module(input)  # fully materialized
                module = pickle.loads(pickle.dumps(module))

                self.assertNotIsInstance(module, lazy_cls)
                self.assertIsInstance(module, cls)
                if affine:
                    self.assertNotIsInstance(module.weight, UninitializedParameter)
                    self.assertNotIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertNotIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertNotIsInstance(module.running_var, UninitializedBuffer)

    def _check_lazy_batchnorm_state(self, cls, lazy_cls):
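        """Loading a regular batch norm state dict must initialize the lazy
        module's parameters and running statistics; loading a lazy state dict
        into a regular module must raise."""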
        module = cls(10)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized, but the module won't become a full
        # BatchNorm one until the first forward pass. This is due to
        # limitations of the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, (10,))
        self.assertEqual(lazy_module.bias.shape, (10,))
        self.assertEqual(lazy_module.running_mean.shape, (10,))
        self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10)
        lazy_module = lazy_cls()
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_instancenorm_state(self, cls, lazy_cls):
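        """Same check as ``_check_lazy_batchnorm_state``, but exercising every
        combination of affine and track_running_stats."""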
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = cls(10, affine=affine, track_running_stats=track_running_stats)
                lazy_module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )
                lazy_module.load_state_dict(module.state_dict())
                # Parameters have been initialized, but the module won't become a full
                # InstanceNorm one until the first forward pass. This is due to
                # limitations of the state_dict loading logic
                self.assertFalse(lazy_module.has_uninitialized_params())
                if affine:
                    self.assertEqual(lazy_module.weight.shape, (10,))
                    self.assertEqual(lazy_module.bias.shape, (10,))
                if track_running_stats:
                    self.assertEqual(lazy_module.running_mean.shape, (10,))
                    self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10, affine=True, track_running_stats=True)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_norm_with_dict_input(self, cls, lazy_cls, input_shape):
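        """Verify that a lazy norm module materializes correctly when its input
        is passed as a keyword argument and matches the eager module's output."""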
        input = {"input": torch.ones(*input_shape)}

        lazy_module = lazy_cls()
        lazy_output = lazy_module(**input)

        num_features = input_shape[1]
        module = cls(num_features)
        expected_output = module(**input)

        self.assertEqual(lazy_output, expected_output)

    def test_lazy_batchnorm1d(self):
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)

    def test_lazy_batchnorm2d(self):
        self._check_lazy_norm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)

    def test_lazy_batchnorm3d(self):
        self._check_lazy_norm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_batchnorm3d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_batchnorm3d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)

    def test_lazy_instancenorm1d(self):
        self._check_lazy_norm(nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6))

    def test_lazy_instancenorm1d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6)
        )

    def test_lazy_instancenorm1d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)

    def test_lazy_instancenorm2d(self):
        self._check_lazy_norm(nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7))

    def test_lazy_instancenorm2d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7)
        )

    def test_lazy_instancenorm2d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)

    def test_lazy_instancenorm3d(self):
        self._check_lazy_norm(
            nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_instancenorm3d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_instancenorm3d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)

    def test_lazy_batchnorm_with_dict_input(self):
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6)
        )
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7)
        )
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)
        )

    @suppress_warnings
    def test_materialize_dtype(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.get_default_dtype())
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.half()
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.float16)

    @unittest.skipIf(
        not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 not available"
    )
    @suppress_warnings
    def test_materialize_device(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == "cpu")
        if TEST_CUDA:
            device = "cuda"
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.to(device)
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == device)

    @suppress_warnings
    def test_chained_initialization(self):
        class MyNetwork(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_1 = torch.nn.LazyLinear(15)
                self.linear_2 = torch.nn.LazyLinear(10)

            def forward(self, x):
                y = self.linear_1(x)
                return self.linear_2(y)

        net = MyNetwork()
        net(torch.ones(5, 10))
        self.assertTrue(net.linear_1.weight.shape == (15, 10))
        self.assertTrue(net.linear_1.bias.shape == (15,))
        self.assertTrue(net.linear_2.weight.shape == (10, 15))
        self.assertTrue(net.linear_2.bias.shape == (10,))

    @suppress_warnings
    def test_optimizer_pass(self):
        optimizers = [
            torch.optim.Adadelta,
            torch.optim.Adagrad,
            torch.optim.Adamax,
            torch.optim.Adam,
            torch.optim.AdamW,
            torch.optim.ASGD,
            torch.optim.SGD,
            torch.optim.Rprop,
            torch.optim.RMSprop,
            torch.optim.LBFGS,
            torch.optim.NAdam,
            torch.optim.RAdam,
        ]

        def run_step(module, optim):
            self.assertIsInstance(
                optim.param_groups[0]["params"][0], UninitializedParameter
            )
            module.test_param.materialize(10)
            self.assertIsInstance(optim.param_groups[0]["params"][0], Parameter)
            self.assertNotIsInstance(
                optim.param_groups[0]["params"][0], UninitializedParameter
            )
            for p in module.parameters():
                p.grad = torch.rand_like(p)
            if isinstance(optim, torch.optim.LBFGS):
                optim.step(lambda: 1.0)
            else:
                optim.step()

        for optim_cls in optimizers:
            module = LazyModule()
            module.register_parameter("test_param", UninitializedParameter())
            if optim_cls is torch.optim.SGD:
                optim = optim_cls(module.parameters(), lr=0.0)
            elif optim_cls is torch.optim.Adagrad:
                with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
                    optim = optim_cls(module.parameters())
                continue
            else:
                optim = optim_cls(module.parameters())
            run_step(module, optim)

    @suppress_warnings
    def test_weight_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, "have uninitialized parameters."):
            m = torch.nn.utils.weight_norm(m)

    @suppress_warnings
    def test_spectral_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, "have uninitialized parameters."):
            m = torch.nn.utils.spectral_norm(m)

    @suppress_warnings
    def test_invalid_functions(self):
        param = torch.nn.parameter.UninitializedParameter()
        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            torch.empty_like(param)

        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            torch.add(param, param)

        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            param + param


if __name__ == "__main__":
    run_tests()