# Owner(s): ["module: nn"]
import re
import unittest
from copy import deepcopy
from itertools import product

import torch
import torch.nn as nn
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfCrossRef,
    skipIfTorchDynamo,
    swap,
    TEST_NUMPY,
    TestCase,
)
from torch.utils._pytree import tree_map


if TEST_NUMPY:
    import numpy as np


class TestLoadStateDict(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    @swap([True, False])
    def test_load_state_dict_invalid(self):
        m = torch.nn.Linear(2, 2, bias=False)

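        # Non-Tensor values (e.g. a numpy array or a nested tuple) are not
        # Tensor-like and should be rejected with a clear error.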
        state_dict = {"weight": np.random.randn(2, 2)}
        with self.assertRaisesRegex(
            RuntimeError,
            "expected torch.Tensor or Tensor-like object from checkpoint but received",
        ):
            m.load_state_dict(state_dict)

        state_dict = {"weight": ((1.0, 1.0), (2.0, 2.0))}
        with self.assertRaisesRegex(
            RuntimeError,
            "expected torch.Tensor or Tensor-like object from checkpoint but received",
        ):
            m.load_state_dict(state_dict)

    @swap([True, False])
    def test_load_state_dict_type(self):
        m = nn.Module()

        with self.assertRaisesRegex(
            TypeError, "Expected state_dict to be dict-like, got"
        ):
            m.load_state_dict("")
        with self.assertRaisesRegex(
            TypeError, "Expected state_dict to be dict-like, got"
        ):
            m.load_state_dict(2)

    @swap([True, False])
    @skipIfTorchDynamo("dynamo installs weakrefs on some params")
    def test_load_state_dict(self):
        l = nn.Linear(5, 5)
        block = nn.Module()
        block.conv1 = nn.Conv2d(3, 3, 3, bias=True)
        block.conv2 = nn.Conv2d(3, 3, 3, bias=False)
        net = nn.Module()
        net.linear1 = l
        net.linear2 = l
        net.bn = nn.BatchNorm2d(2)
        net.block = block
        net.add_module("empty", None)
        conv1_bias_dtype = block.conv1.bias.dtype

        state_dict = net.state_dict()
        state_dict.update(
            {
                "linear1.weight": torch.ones(5, 5),
                "block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
                "bn.running_mean": torch.randn(2),
            }
        )
        # Also test that a DDP state_dict can be loaded into a local model.
        ddp_state_dict = net.state_dict()
        ddp_state_dict.update(
            {
                "module.linear1.weight": torch.ones(5, 5),
                "module.block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
                "module.bn.running_mean": torch.randn(2),
            }
        )
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        for sd in [state_dict, ddp_state_dict]:
            incompatible_keys = net.load_state_dict(sd)
            self.assertEqual(len(incompatible_keys.missing_keys), 0)
            self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
            self.assertNotIn("Incompatible", str(incompatible_keys))
            self.assertEqual(net.linear1.weight, sd["linear1.weight"])
            self.assertEqual(net.block.conv1.bias, sd["block.conv1.bias"])
            self.assertEqual(net.bn.running_mean, sd["bn.running_mean"])

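        # Extra keys: strict loading should fail, while strict=False should
        # report them in unexpected_keys.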
        state_dict = net.state_dict()
        state_dict.update({"extra": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("extra", incompatible_keys.unexpected_keys)
        self.assertIn("Incompatible", str(incompatible_keys))

        state_dict = net.state_dict()
        state_dict.update({"extra.param": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("extra.param", incompatible_keys.unexpected_keys)

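        # Missing keys: strict loading should fail, while strict=False should
        # report them in missing_keys.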
        state_dict = net.state_dict()
        del state_dict["linear1.weight"]
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
        self.assertIn("linear1.weight", incompatible_keys.missing_keys)
        state_dict.update({"extra.param": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("linear1.weight", incompatible_keys.missing_keys)
        self.assertIn("extra.param", incompatible_keys.unexpected_keys)

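        # A shape mismatch is an error even with strict=False.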
        state_dict = net.state_dict()
        state_dict.update({"bn.running_mean": torch.rand(14, 4)})  # wrong size
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        self.assertRaises(
            RuntimeError, lambda: net.load_state_dict(state_dict, strict=False)
        )

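        # Non-strict loading with a mix of valid and nonexistent keys should
        # update only the matching entries and leave everything else untouched.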
        state_dict = net.state_dict()
        old_state_dict = deepcopy(state_dict)
        state_dict = {
            "linear1.weight": torch.ones(5, 5),
            "block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
            "bn.running_mean": torch.randn(2),
            "nonexistent_key": torch.rand(3),
        }
        net.load_state_dict(state_dict, strict=False)
        self.assertEqual(net.linear1.weight, state_dict["linear1.weight"])
        self.assertEqual(net.block.conv1.bias, state_dict["block.conv1.bias"])
        self.assertEqual(net.bn.running_mean, state_dict["bn.running_mean"])
        new_state_dict = net.state_dict()
        del old_state_dict["linear1.weight"]
        del old_state_dict["block.conv1.bias"]
        del old_state_dict["bn.running_mean"]
        for k, v in old_state_dict.items():
            self.assertTrue(v.equal(new_state_dict[k]))

    @swap([True, False])
    def test_load_state_dict_BC(self):
        # BatchNormNd
        # The num_batches_tracked buffer was added in version 2. For state dicts
        # with an earlier version or no version, loading should provide a
        # default value of 0.
        bn = nn.BatchNorm2d(3)
        state_dict = bn.state_dict()
        del state_dict["num_batches_tracked"]
        state_dict._metadata[""]["version"] = 1  # version 1
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)
        del state_dict._metadata[""]["version"]  # no version
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)

    @swap([True, False])
    def test_load_state_dict_child(self):
        base_module = nn.Linear(1, 1)
        model = base_module
        for _ in range(3):
            model = nn.Sequential(*[deepcopy(model) for _ in range(10)])

        def hook_fn(
            module,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            module_state_dict = module.state_dict()
            self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys()))

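        # The pre-hook registered on a deeply nested child should only receive
        # the keys that belong to that child.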
        model[0][0].register_load_state_dict_pre_hook(hook_fn)
        model.load_state_dict(model.state_dict(), strict=True)

    # fails swapping as LSTM installs weak references on the parameters
    @swap([False])
    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    def test_load_state_dict_ref_cycle(self):
        # load_state_dict shouldn't cause a reference cycle involving Tensors
        import gc

        m = torch.nn.LSTM(16, 16, bidirectional=True)

        gc.collect()
        m.load_state_dict(deepcopy(m).state_dict())
        refcycles = gc.collect()

        self.assertEqual(refcycles, 0)

    @swap([True, False])
    def test_load_state_dict_custom(self):
        class CustomState(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.ones(1))
                self.sub = torch.nn.Linear(5, 5)

            def _save_to_state_dict(self, destination, prefix, keep_vars):
                destination[prefix + "serialized"] = self.param.data + 1

            def _load_from_state_dict(
                self,
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            ):
                # skip some of the error handling
                self.param.data.copy_(state_dict[prefix + "serialized"] - 1)

        # use sequential to verify nesting
        m = nn.Sequential(CustomState())
        with torch.no_grad():
            m[0].param[0] = 10
            m[0].sub.weight[0, 0] = 555
        state_dict = m.state_dict()
        self.assertEqual(state_dict["0.serialized"].item(), 11)
        self.assertIn("0.sub.weight", state_dict)
        self.assertNotIn("0.param", state_dict)
        del m
        mm = nn.Sequential(CustomState())
        self.assertEqual(mm[0].param[0].item(), 1)
        mm.load_state_dict(state_dict)
        self.assertEqual(mm[0].param[0].item(), 10)
        self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)

    @swap([True, False])
    @parametrize("keep_vars", [True, False])
    def test_load_state_dict_assign_meta(self, keep_vars):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)
                self.x = nn.Parameter(torch.rand(5), requires_grad=False)

            def forward(self, input):
                return self.x + self.bn(self.fc1(input))

        swap = torch.__future__.get_swap_module_params_on_conversion()
        net = MyModule()
        state_dict = net.state_dict(keep_vars=keep_vars)
        for v in state_dict.values():
            v.requires_grad_(False)

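        # With assign=True, loading into a meta-device module should take the
        # values from the state_dict rather than copying them into the meta
        # tensors (which would be a no-op).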
        with torch.device("meta"):
            net_meta = MyModule()

        net_meta_state_dict_old = net_meta.state_dict(keep_vars=True)
        net_meta.load_state_dict(state_dict, assign=True)

        # Make sure parameters and persistent buffers were assigned
        net_meta_state_dict = net_meta.state_dict(keep_vars=True)
        for key in state_dict.keys():
            if key in net_meta._parameters:
                if keep_vars and not swap:
                    # state_dict[key] is an nn.Parameter
                    self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                else:
                    if swap:
                        self.assertTrue(
                            net_meta_state_dict[key] is net_meta_state_dict_old[key]
                        )
                    else:
                        # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
                        self.assertTrue(
                            net_meta_state_dict[key] is not net_meta_state_dict_old[key]
                        )
                        self.assertEqual(
                            net_meta_state_dict_old[key].requires_grad,
                            net_meta_state_dict[key].requires_grad,
                        )
                self.assertEqual(
                    net_meta_state_dict_old[key].requires_grad,
                    net_meta_state_dict[key].requires_grad,
                )
                self.assertEqual(state_dict[key], net_meta_state_dict[key])
            elif (
                key in net_meta._buffers
                and key not in net_meta._non_persistent_buffers_set
            ):
                self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                self.assertEqual(state_dict[key], net_meta_state_dict[key])

        # Make sure that ordering of parameters and buffers is preserved
        net_named_parameters = net.named_parameters()
        net_named_buffers = net.named_buffers()
        net_meta_named_parameters = net_meta.named_parameters()
        net_meta_named_buffers = net_meta.named_buffers()

        for (n1, _), (n2, _) in zip(net_named_parameters, net_meta_named_parameters):
            self.assertEqual(n1, n2)

        for (n1, _), (n2, _) in zip(net_named_buffers, net_meta_named_buffers):
            self.assertEqual(n1, n2)

        # Make sure outputs are the same
        t = torch.randn(4, 3)
        out_net = net(t)
        out_net_meta = net_meta(t.clone())

        self.assertEqual(out_net, out_net_meta)

    @swap([True, False])
    def test_load_state_dict_assign_with_optimizer(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        opt = torch.optim.Adam(net.parameters(), lr=1000)
        x = torch.randn(4, 3)
        num_iters = 3

        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

        opt_state_dict = deepcopy(opt.state_dict())
        net_state_dict = deepcopy(net.state_dict())

        with torch.device("meta"):
            net_meta = MyModule()

        net_meta.load_state_dict(net_state_dict, assign=True)
        # must create optimizer only after loading state_dict when assign=True
        opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000)
        opt2.load_state_dict(opt_state_dict)

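        # After assignment, training the original and the reloaded model in
        # lockstep should produce identical parameter and optimizer state.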
        y = x.clone()
        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

            opt2.zero_grad()
            out2 = net_meta(y)
            out2.sum().backward()
            opt2.step()

        self.assertEqual(opt.state_dict(), opt2.state_dict())
        self.assertEqual(net.state_dict(), net_meta.state_dict())

    @swap([True, False])
    def test_load_state_dict_assign_shape_stride(self):
        # An assigned tensor is allowed to have different properties from the
        # initial tensor, except for shape.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict()
        # loading should be ok if stride is different
        state_dict["fc1.weight"] = torch.randn(3, 5).transpose(0, 1)
        net2 = MyModule()
        net2.load_state_dict(state_dict, strict=False, assign=True)

        state_dict["fc1.weight"] = torch.randn(2, 4)
        with self.assertRaisesRegex(
            RuntimeError, "size mismatch for fc1.weight: copying a param with shape"
        ):
            net2.load_state_dict(state_dict, strict=False, assign=True)

    @swap([True, False])
    def test_load_state_dict_warn_assign(self):
        with torch.device("meta"):
            m = torch.nn.Linear(3, 5)
        state_dict = m.state_dict()
        state_dict["weight"] = torch.empty_like(state_dict["weight"], device="cpu")
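        # Copying a real (CPU) tensor onto a meta parameter is a no-op, so
        # load_state_dict should warn (the full message suggests assign=True).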
        with self.assertWarnsRegex(
            UserWarning,
            "for weight: copying from a non-meta parameter in the checkpoint to a meta",
        ):
            m.load_state_dict(state_dict)

    @swap([True, False])
    def test_load_state_dict_with_unexpected_key(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(5, 10)

        m = MyModule()

        # Unexpected key & strict = True
        with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
            state_dict = m.state_dict()
            state_dict["fc1.bad_suffix"] = torch.randn(5, 10)
            m.load_state_dict(state_dict)

        # Unexpected key & strict = False
        state_dict = m.load_state_dict(state_dict, strict=False)
        self.assertIn("fc1.bad_suffix", state_dict.unexpected_keys)

        # Unexpected key whose prefix matches a valid key & strict = True
        with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
            state_dict = m.state_dict()
            state_dict["fc1.weight.bad_suffix"] = torch.randn(5, 10)
            m.load_state_dict(state_dict)

        # Unexpected key whose prefix matches a valid key & strict = False
        state_dict = m.load_state_dict(state_dict, strict=False)
        self.assertIn("fc1.weight.bad_suffix", state_dict.unexpected_keys)

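# Shared __torch_function__ handler for the tensor subclasses defined below.
# It overrides torch.Tensor.module_load (used by load_state_dict when module
# param swapping is enabled) to decide the type of the loaded tensor from the
# destination/source subclasses, and makes Tensor.detach return an instance of
# the same subclass so the result can be re-wrapped in an nn.Parameter.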
def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs

    def module_load(dest, src, assign=False):
        if isinstance(dest, cls):
            if assign:
                return src.detach()
            else:
                if type(src) is torch.Tensor:
                    return cls(src)
                elif type(src) is cls:
                    return src.detach()
                else:
                    if isinstance(src, MyWrapperLoadTensor):
                        return cls(src._data)
                    return cls(src)
        else:
            assert isinstance(
                src, cls
            ), f"Expected isinstance(src, {cls}) but got {type(src)}"
            assert (
                type(dest) == torch.Tensor
                or type(dest) == torch.nn.Parameter
                or issubclass(cls, type(dest))
            )
            if assign:
                return src.detach()
            else:
                if isinstance(src, MyWrapperLoadTensor):
                    if type(dest) not in {torch.Tensor, torch.nn.Parameter}:
                        return type(dest)(src._data)
                    else:
                        return src._data.detach()
                else:
                    return torch.Tensor(src)

    if func is torch.Tensor.module_load:
        return module_load(*args, **kwargs)
    else:
        with torch._C.DisableTorchFunctionSubclass():
            # detach must return instance of same subclass for nn.Parameter()
            if func == torch.Tensor.detach:
                ret = func(*args, **kwargs)
                if not isinstance(ret, cls):
                    return cls(ret)
                return ret
            return func(*args, **kwargs)


class MyLoadTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return load_torch_function_handler(cls, func, types, args, kwargs)

# We use MyLoadTensor2 to test the (tensor subclass, wrapper tensor subclass)
# case where neither class inherits from the other.
class MyLoadTensor2(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return load_torch_function_handler(cls, func, types, args, kwargs)


class MyBrokenLoadTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.Tensor.module_load:
            # wrong as this doesn't detach!
            return args[1]
        else:
            with torch._C.DisableTorchFunctionSubclass():
                # detach must return instance of same subclass for nn.Parameter()
                if func == torch.Tensor.detach:
                    return cls(func(*args, **kwargs))
                return func(*args, **kwargs)

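# A wrapper tensor subclass: the real data lives in self._data, and
# __torch_dispatch__ unwraps it before dispatching and re-wraps the result.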
class MyWrapperLoadTensor(MyLoadTensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        t = torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            requires_grad=data.requires_grad,
            strides=data.stride(),
            storage_offset=data.storage_offset(),
        )
        return t

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __repr__(self):
        return f"MyWrapperLoadTensor({self._data.__repr__()})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t._data if isinstance(t, MyWrapperLoadTensor) else t

        def wrap(t):
            return MyWrapperLoadTensor(t) if isinstance(t, torch.Tensor) else t

        kwargs = {} if kwargs is None else kwargs
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)


class TestLoadStateDictSwap(TestCase):
    @skipIfCrossRef
    @skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs")
    @swap([True])
    @parametrize("assign", [True, False])
    def test_swap_subclass(self, assign):
        def _create_model(subclass=None):
            m = torch.nn.Linear(2, 3, bias=False)
            m.buf = torch.nn.Buffer(torch.randn(2, 3))
            if subclass is not None:
                m.weight = torch.nn.Parameter(subclass(m.weight))
                m.buf = subclass(m.buf)
            return m

        def _test(m_subclass=None, sd_subclass=None):
            m = _create_model(m_subclass)
            sd = _create_model(sd_subclass).state_dict()
            m.load_state_dict(sd, assign=assign)
            self.assertEqual(m.weight, sd["weight"])
            self.assertEqual(m.buf, sd["buf"])
            self.assertTrue(isinstance(m.weight, torch.nn.Parameter))
            self.assertTrue(not isinstance(m.buf, torch.nn.Parameter))

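            # With assign=True the resulting type follows the state_dict
            # tensor's subclass; otherwise it follows the module's original
            # subclass.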
            weight_type, buf_type = (torch.nn.Parameter, torch.Tensor)
            if assign:
                if sd_subclass is not None:
                    weight_type, buf_type = (sd_subclass, sd_subclass)
            else:
                if m_subclass is not None:
                    weight_type, buf_type = (m_subclass, m_subclass)

            self.assertTrue(type(m.weight) is weight_type)
            self.assertTrue(type(m.buf) is buf_type)

        # (MyLoadTensor, MyWrapperLoadTensor) tests the behavior of (superclass, subclass)
        subclasses = [None, MyLoadTensor, MyLoadTensor2, MyWrapperLoadTensor]
        for m_s, sd_s in product(subclasses, subclasses):
            _test(m_s, sd_s)

        # MyBrokenLoadTensor should error since its module_load doesn't call .detach()
        with self.assertRaisesRegex(
            RuntimeError, re.escape("Error(s) in loading state_dict for Linear:")
        ):
            _test(None, MyBrokenLoadTensor)


instantiate_parametrized_tests(TestLoadStateDict)
instantiate_parametrized_tests(TestLoadStateDictSwap)

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()