# Owner(s): ["module: nn"]

import tempfile
from copy import deepcopy
from functools import partial
from unittest import expectedFailure

import torch
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.utils.parametrize import (
    register_parametrization,
    remove_parametrizations,
)
from torch.testing._internal.common_subclass import (
    DiagTensorBelow,
    subclass_db,
)
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    subtest,
)
from torch.testing._internal.logging_tensor import LoggingTensor
from torch.utils._pytree import tree_map

# The current test methodology in this file is to test a variety of real use cases
# with a set of fully-fledged tensor subclasses. In the future, this may change
# to more narrowly specify toy subclasses for each of the specific invariants under
# test, avoiding the need to maintain the set of fully-fledged tensor subclasses.

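# subclass_db (from common_subclass) maps each tensor subclass under test to the
# metadata used below: a factory (create_fn), a display name, and whether the class
# is closed under ops.
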
# Decorator for parametrizing tests across the various tensor classes.
parametrize_tensor_cls = parametrize("tensor_cls", [
    subtest(tensor_cls, name=info.name) for tensor_cls, info in subclass_db.items()])


class TestSubclass(TestCase):
    def _create_tensor(self, tensor_cls):
        return subclass_db[tensor_cls].create_fn(3)

    @parametrize_tensor_cls
    @parametrize("tensor_requires_grad", [False, True])
    def test_param_invariants(self, tensor_cls, tensor_requires_grad):
        x = self._create_tensor(tensor_cls).requires_grad_(tensor_requires_grad)
        param = nn.Parameter(x, requires_grad=(not tensor_requires_grad))

        self.assertIsInstance(param, nn.Parameter)
        # Ensure requires_grad passed to Parameter's constructor takes precedence.
        self.assertEqual(param.requires_grad, not tensor_requires_grad)

        # Ensure original tensor is not mutated by Parameter construction.
        self.assertNotIsInstance(x, nn.Parameter)
        self.assertEqual(x.requires_grad, tensor_requires_grad)

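        # A Parameter constructed from a subclass tensor should not spuriously satisfy
        # isinstance checks against stricter nn.Parameter subclasses such as this one.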
        class UninitializedParam(nn.Parameter):
            pass

        self.assertNotIsInstance(param, UninitializedParam)

    @skipIfTorchDynamo()
    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_deepcopy(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        x_copy = deepcopy(x)
        self.assertEqual(x, x_copy)
        self.assertEqual(x.__class__, x_copy.__class__)
        self.assertIsNot(x, x_copy)
        self.assertIsInstance(x_copy, tensor_cls)
        if as_param:
            # Deepcopy should preserve both custom type and "parameter-ness".
            self.assertIsInstance(x_copy, nn.Parameter)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_serialization(self, tensor_cls, as_param):
        with tempfile.TemporaryFile() as f:
            x = self._create_tensor(tensor_cls)
            if as_param:
                x = nn.Parameter(x)
            torch.save(x, f)
            f.seek(0)
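            # Allowlist the subclass so torch.load can reconstruct the custom class
            # under weights_only deserialization.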
            with torch.serialization.safe_globals([tensor_cls]):
                x_loaded = torch.load(f)

            self.assertEqual(x, x_loaded)
            self.assertIsNot(x, x_loaded)
            self.assertIsInstance(x_loaded, tensor_cls)
            if as_param:
                # Serialization should preserve both custom type and "parameter-ness".
                self.assertIsInstance(x_loaded, nn.Parameter)

    @skipIfTorchDynamo("Visible only with functorch as functorch monkeypatches tensor str")
    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_repr(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        str_repr = repr(x)
        if tensor_cls is not torch.Tensor:
            self.assertEqual(str_repr.count(f"{tensor_cls.__name__}("), 1)
        self.assertEqual(str_repr.count("Parameter"), 1 if as_param else 0)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_type_propagation(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)

        # Call the add operator to produce an output tensor.
        output = x + self._create_tensor(torch.Tensor)

        # Custom type should be propagated across operations if closed under the op, but
        # "parameter-ness" should not be.
        if subclass_db[tensor_cls].closed_under_ops:
            self.assertIsInstance(output, tensor_cls)
        else:
            self.assertIsInstance(output, torch.Tensor)
        self.assertNotIsInstance(output, nn.Parameter)

    @parametrize_tensor_cls
    def test_module_optimization(self, tensor_cls):
        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p1 = nn.Parameter(create_fn())

                self.p_list = nn.ParameterList([create_fn() for _ in range(3)])
                self.p_list.append(create_fn())

                self.p_dict = nn.ParameterDict({
                    'foo': create_fn(),
                    'bar': create_fn(),
                })
                self.p_dict['baz'] = create_fn()

                with torch.no_grad():
                    nn.init.normal_(self.p1)
                    for p in self.p_list:
                        nn.init.uniform_(p)
                    for p in self.p_dict.values():
                        nn.init.uniform_(p)

            def forward(self, x):
                out = self.p1 + x
                for p in self.p_list:
                    out = p + out

                for v in self.p_dict.values():
                    out = v + out

                return out

        m = MyModule()
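        # Expect 8 state dict entries: p1, four entries in p_list (3 initial + 1 appended),
        # and three in p_dict ('foo', 'bar', 'baz').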
        self.assertEqual(len(m.state_dict()), 8)

        optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
        m(create_fn()).sum().backward(torch.tensor(1))
        optimizer.step()

    @parametrize_tensor_cls
    @parametrize("leave_parametrized", [False, True])
    def test_parametrization(self, tensor_cls, leave_parametrized):
        # TODO: Either implement set_() properly for these tensor subclasses or apply a
        # more general fix to avoid the need for special set_() handling. For now, skip
        # testing these as they're expected to fail.
        if tensor_cls in [LoggingTensor, DiagTensorBelow]:
            return

        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(create_fn())

            def forward(self, x):
                return self.weight + x

        class MyParametrization(nn.Module):
            def forward(self, X):
                return -X

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 1)
        register_parametrization(m, 'weight', MyParametrization())
        self.assertIsInstance(m.weight, tensor_cls)
        output = m(self._create_tensor(torch.Tensor))
        self.assertIsInstance(output, tensor_cls)
        remove_parametrizations(m, 'weight', leave_parametrized=leave_parametrized)

    # Lazy modules with custom tensors are not supported yet.
    @expectedFailure
    @parametrize_tensor_cls
    def test_lazy_module(self, tensor_cls):
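        # expectedFailure applies to every parametrization; plain torch.Tensor would
        # otherwise pass, so force a failure for it until the subclass cases work.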
        if tensor_cls is torch.Tensor:
            self.fail('dummy fail for base tensor until the test passes for subclasses')

        class MyLazyModule(LazyModuleMixin, nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = nn.UninitializedParameter()

            def initialize_parameters(self, input) -> None:  # type: ignore[override]
                if self.has_uninitialized_params():
                    with torch.no_grad():
                        self.param.materialize(input.shape)
                        nn.init.uniform_(self.param)

            def forward(self, x):
                return self.param + x

        m = MyLazyModule()
        self.assertTrue(m.has_uninitialized_params())
        output = m(self._create_tensor(tensor_cls))
        self.assertFalse(m.has_uninitialized_params())
        self.assertIsInstance(m.param, tensor_cls)

    def test_non_rewrapping_torch_dispatch_subclass_as_parameter_throws_for_detach(self):

        # Define a subclass that does not rewrap for any function in its __torch_dispatch__ impl.
        class NonRewrappingTensor(torch.Tensor):
            @staticmethod
            def __new__(
                cls, t: torch.Tensor
            ):
                r = super()._make_wrapper_subclass(
                    cls, t.shape, dtype=t.dtype, requires_grad=t.requires_grad, device=t.device)
                return r

            def __init__(self, t) -> None:
                self.tensor: torch.Tensor = t

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

                def unwrap(e) -> torch.Tensor:
                    if isinstance(e, NonRewrappingTensor):
                        t = e.tensor
                        return t
                    else:
                        return e

                r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                # Return an unwrapped tensor no longer of original subclass type.
                return r

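        # Constructing an nn.Parameter from a __torch_dispatch__ subclass requires that
        # detach() return the same subclass; since NonRewrappingTensor unwraps every
        # result, detach() yields a plain Tensor and Parameter construction fails.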
        with self.assertRaisesRegex(RuntimeError, r"requires that detach\(\) returns an instance of the same type"):
            param = nn.Parameter(NonRewrappingTensor(torch.randn(3)))

    def test_tensor_subclass_storage_data_accesses_throw(self):
        x = torch.ones(2)
        x_log = LoggingTensor(x)
        # Accessing storage on a tensor subclass is valid
        storage = x_log.untyped_storage()
        # This includes accessing metadata on the storage
        sz = storage.size()
        # But storage methods that access data will throw, since the wrapper
        # subclass's storage does not hold any real data
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.data_ptr()
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.resize_(0)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.copy_(storage)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.fill_(0)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage._write_file("file")


instantiate_parametrized_tests(TestSubclass)

if __name__ == '__main__':
    run_tests()