# Owner(s): ["module: nn"]

import tempfile
from copy import deepcopy
from functools import partial
from unittest import expectedFailure

import torch
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.utils.parametrize import (
    register_parametrization,
    remove_parametrizations,
)
from torch.testing._internal.common_subclass import (
    DiagTensorBelow,
    subclass_db,
)
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    subtest,
)
from torch.testing._internal.logging_tensor import LoggingTensor
from torch.utils._pytree import tree_map

# The current test methodology in this file is to test a variety of real use cases
# with a set of fully-fledged tensor subclasses. In the future, this may change
# to more narrowly specify toy subclasses for each of the specific invariants under
# test, avoiding the need to maintain the set of fully-fledged tensor subclasses.


# Decorator for parametrizing tests across the various tensor classes.
parametrize_tensor_cls = parametrize("tensor_cls", [
    subtest(tensor_cls, name=info.name) for tensor_cls, info in subclass_db.items()])


class TestSubclass(TestCase):
    def _create_tensor(self, tensor_cls):
        return subclass_db[tensor_cls].create_fn(3)

    @parametrize_tensor_cls
    @parametrize("tensor_requires_grad", [False, True])
    def test_param_invariants(self, tensor_cls, tensor_requires_grad):
        x = self._create_tensor(tensor_cls).requires_grad_(tensor_requires_grad)
        param = nn.Parameter(x, requires_grad=(not tensor_requires_grad))

        self.assertIsInstance(param, nn.Parameter)
        # Ensure requires_grad passed to Parameter's constructor takes precedence.
        self.assertEqual(param.requires_grad, not tensor_requires_grad)

        # Ensure original tensor is not mutated by Parameter construction.
        self.assertNotIsInstance(x, nn.Parameter)
        self.assertEqual(x.requires_grad, tensor_requires_grad)

        class UninitializedParam(nn.Parameter):
            pass

        self.assertNotIsInstance(param, UninitializedParam)

    @skipIfTorchDynamo()
    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_deepcopy(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        x_copy = deepcopy(x)
        self.assertEqual(x, x_copy)
        self.assertEqual(x.__class__, x_copy.__class__)
        self.assertIsNot(x, x_copy)
        self.assertIsInstance(x_copy, tensor_cls)
        if as_param:
            # Deepcopy should preserve both custom type and "parameter-ness".
            self.assertIsInstance(x_copy, nn.Parameter)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_serialization(self, tensor_cls, as_param):
        with tempfile.TemporaryFile() as f:
            x = self._create_tensor(tensor_cls)
            if as_param:
                x = nn.Parameter(x)
            torch.save(x, f)
            f.seek(0)
            with torch.serialization.safe_globals([tensor_cls]):
                x_loaded = torch.load(f)

            self.assertEqual(x, x_loaded)
            self.assertIsNot(x, x_loaded)
            self.assertIsInstance(x_loaded, tensor_cls)
            if as_param:
                # Serialization should preserve both custom type and "parameter-ness".
                self.assertIsInstance(x_loaded, nn.Parameter)

    @skipIfTorchDynamo("Visible only with functorch as functorch monkeypatches tensor str")
    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_repr(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        str_repr = x.__repr__()
        if tensor_cls is not torch.Tensor:
            self.assertEqual(str_repr.count(f"{tensor_cls.__name__}("), 1)
        self.assertEqual(str_repr.count("Parameter"), 1 if as_param else 0)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_type_propagation(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)

        # Call the add operator to produce an output tensor.
        output = x + self._create_tensor(torch.Tensor)

        # Custom type should be propagated across operations if closed under the op, but
        # "parameter-ness" should not be.
        if subclass_db[tensor_cls].closed_under_ops:
            self.assertIsInstance(output, tensor_cls)
        else:
            self.assertIsInstance(output, torch.Tensor)
        self.assertNotIsInstance(output, nn.Parameter)

    @parametrize_tensor_cls
    def test_module_optimization(self, tensor_cls):
        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p1 = nn.Parameter(create_fn())

                self.p_list = nn.ParameterList([create_fn() for _ in range(3)])
                self.p_list.append(create_fn())

                self.p_dict = nn.ParameterDict({
                    'foo': create_fn(),
                    'bar': create_fn(),
                })
                self.p_dict['baz'] = create_fn()

                with torch.no_grad():
                    nn.init.normal_(self.p1)
                    for p in self.p_list:
                        nn.init.uniform_(p)
                    for p in self.p_dict.values():
                        nn.init.uniform_(p)

            def forward(self, x):
                out = self.p1 + x
                for p in self.p_list:
                    out = p + out

                for v in self.p_dict.values():
                    out = v + out

                return out

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 8)

        optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
        # Use a floating-point gradient to match the dtype of the scalar output.
        m(create_fn()).sum().backward(torch.tensor(1.0))
        optimizer.step()

    @parametrize_tensor_cls
    @parametrize("leave_parametrized", [False, True])
    def test_parametrization(self, tensor_cls, leave_parametrized):
        # TODO: Either implement set_() properly for these tensor subclasses or apply a
        # more general fix to avoid the need for special set_() handling. For now, skip
        # testing these as they're expected to fail.
        if tensor_cls in [LoggingTensor, DiagTensorBelow]:
            return

        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(create_fn())

            def forward(self, x):
                return self.weight + x

        class MyParametrization(nn.Module):
            def forward(self, X):
                return -X

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 1)
        register_parametrization(m, 'weight', MyParametrization())
        self.assertIsInstance(m.weight, tensor_cls)
        output = m(self._create_tensor(torch.Tensor))
        self.assertIsInstance(output, tensor_cls)
        remove_parametrizations(m, 'weight', leave_parametrized=leave_parametrized)

    # Lazy modules with custom tensors are not supported yet.
    @expectedFailure
    @parametrize_tensor_cls
    def test_lazy_module(self, tensor_cls):
        if tensor_cls is torch.Tensor:
            self.fail('dummy fail for base tensor until the test passes for subclasses')

        class MyLazyModule(LazyModuleMixin, nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = nn.UninitializedParameter()

            def initialize_parameters(self, input) -> None:  # type: ignore[override]
                if self.has_uninitialized_params():
                    with torch.no_grad():
                        self.param.materialize(input.shape)
                        nn.init.uniform_(self.param)

            def forward(self, x):
                return self.param + x

        m = MyLazyModule()
        self.assertTrue(m.has_uninitialized_params())
        output = m(self._create_tensor(tensor_cls))
        self.assertFalse(m.has_uninitialized_params())
        self.assertIsInstance(m.param, tensor_cls)

    def test_non_rewrapping_torch_dispatch_subclass_as_parameter_throws_for_detach(self):

        # Define a subclass that does not rewrap for any function in its __torch_dispatch__ impl.
        class NonRewrappingTensor(torch.Tensor):
            @staticmethod
            def __new__(
                cls, t: torch.Tensor
            ):
                r = super()._make_wrapper_subclass(
                    cls, t.shape, dtype=t.dtype, requires_grad=t.requires_grad, device=t.device)
                return r

            def __init__(self, t) -> None:
                self.tensor: torch.Tensor = t

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

                def unwrap(e) -> torch.Tensor:
                    if isinstance(e, NonRewrappingTensor):
                        t = e.tensor
                        return t
                    else:
                        return e

                r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                # Return an unwrapped tensor no longer of original subclass type.
                return r

        with self.assertRaisesRegex(RuntimeError, r"requires that detach\(\) returns an instance of the same type"):
            param = nn.Parameter(NonRewrappingTensor(torch.randn(3)))

    def test_tensor_subclass_storage_data_accesses_throw(self):
        from torch.testing._internal.logging_tensor import LoggingTensor
        x = torch.ones(2)
        x_log = LoggingTensor(x)
        # Accessing storage on a tensor subclass is valid
        storage = x_log.untyped_storage()
        # This includes accessing metadata on the storage
        sz = storage.size()
        # But storage methods that access data will throw
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.data_ptr()
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.resize_(0)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.copy_(storage)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.fill_(0)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage._write_file("file")


instantiate_parametrized_tests(TestSubclass)

if __name__ == '__main__':
    run_tests()