# Owner(s): ["module: nn"] import tempfile from copy import deepcopy from functools import partial from unittest import expectedFailure import torch from torch import nn from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.utils.parametrize import ( register_parametrization, remove_parametrizations, ) from torch.testing._internal.common_subclass import ( DiagTensorBelow, subclass_db, ) from torch.testing._internal.common_utils import ( TestCase, instantiate_parametrized_tests, parametrize, run_tests, skipIfTorchDynamo, subtest, ) from torch.testing._internal.logging_tensor import LoggingTensor from torch.utils._pytree import tree_map # The current test methodology in this file is to test a variety of real use cases # with a set of fully-fledged tensor subclasses. In the future, this may change # to more narrowly specify toy subclasses for each of the specific invariants under # test, avoiding the need to maintain the set of fully-fledged tensor subclasses. # Decorator for parametrizing tests across the various tensor classes. parametrize_tensor_cls = parametrize("tensor_cls", [ subtest(tensor_cls, name=info.name) for tensor_cls, info in subclass_db.items()]) class TestSubclass(TestCase): def _create_tensor(self, tensor_cls): return subclass_db[tensor_cls].create_fn(3) @parametrize_tensor_cls @parametrize("tensor_requires_grad", [False, True]) def test_param_invariants(self, tensor_cls, tensor_requires_grad): x = self._create_tensor(tensor_cls).requires_grad_(tensor_requires_grad) param = nn.Parameter(x, requires_grad=(not tensor_requires_grad)) self.assertIsInstance(param, nn.Parameter) # Ensure requires_grad passed to Parameter's constructor takes precedence. self.assertEqual(param.requires_grad, not tensor_requires_grad) # Ensure original tensor is not mutated by Parameter construction. self.assertNotIsInstance(x, nn.Parameter) self.assertEqual(x.requires_grad, tensor_requires_grad) class UninitializedParam(nn.Parameter): pass self.assertNotIsInstance(param, UninitializedParam) @skipIfTorchDynamo() @parametrize_tensor_cls @parametrize("as_param", [False, True]) def test_deepcopy(self, tensor_cls, as_param): x = self._create_tensor(tensor_cls) if as_param: x = nn.Parameter(x) x_copy = deepcopy(x) self.assertEqual(x, x_copy) self.assertEqual(x.__class__, x_copy.__class__) self.assertIsNot(x, x_copy) self.assertIsInstance(x_copy, tensor_cls) if as_param: # Deepcopy should preserve both custom type and "parameter-ness". self.assertIsInstance(x_copy, nn.Parameter) @parametrize_tensor_cls @parametrize("as_param", [False, True]) def test_serialization(self, tensor_cls, as_param): with tempfile.TemporaryFile() as f: x = self._create_tensor(tensor_cls) if as_param: x = nn.Parameter(x) torch.save(x, f) f.seek(0) with torch.serialization.safe_globals([tensor_cls]): x_loaded = torch.load(f) self.assertEqual(x, x_loaded) self.assertIsNot(x, x_loaded) self.assertIsInstance(x_loaded, tensor_cls) if as_param: # Serialization should preserve both custom type and "parameter-ness". self.assertIsInstance(x_loaded, nn.Parameter) @skipIfTorchDynamo("Visible only with functorch as functorch monkeypatches tensor str") @parametrize_tensor_cls @parametrize("as_param", [False, True]) def test_repr(self, tensor_cls, as_param): x = self._create_tensor(tensor_cls) if as_param: x = nn.Parameter(x) str_repr = x.__repr__() if tensor_cls is not torch.Tensor: self.assertEqual(str_repr.count(f"{tensor_cls.__name__}("), 1) self.assertEqual(str_repr.count("Parameter"), 1 if as_param else 0) @parametrize_tensor_cls @parametrize("as_param", [False, True]) def test_type_propagation(self, tensor_cls, as_param): x = self._create_tensor(tensor_cls) if as_param: x = nn.Parameter(x) # Call the add operator to produce an output tensor. output = x + self._create_tensor(torch.Tensor) # Custom type should be propagated across operations if closed under the op, but # "parameter-ness" should not be. if subclass_db[tensor_cls].closed_under_ops: self.assertIsInstance(output, tensor_cls) else: self.assertIsInstance(output, torch.Tensor) self.assertNotIsInstance(output, nn.Parameter) @parametrize_tensor_cls def test_module_optimization(self, tensor_cls): create_fn = partial(self._create_tensor, tensor_cls) class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.p1 = nn.Parameter(create_fn()) self.p_list = nn.ParameterList([create_fn() for _ in range(3)]) self.p_list.append(create_fn()) self.p_dict = nn.ParameterDict({ 'foo': create_fn(), 'bar': create_fn(), }) self.p_dict['baz'] = create_fn() with torch.no_grad(): nn.init.normal_(self.p1) for p in self.p_list: nn.init.uniform_(p) for p in self.p_dict.values(): nn.init.uniform_(p) def forward(self, x): out = self.p1 + x for p in self.p_list: out = p + out for v in self.p_dict.values(): out = v + out return out m = MyModule() self.assertEqual(len(m.state_dict()), 8) optimizer = torch.optim.SGD(m.parameters(), lr=0.1) m(create_fn()).sum().backward(torch.tensor(1)) optimizer.step() @parametrize_tensor_cls @parametrize("leave_parametrized", [False, True]) def test_parametrization(self, tensor_cls, leave_parametrized): # TODO: Either implement set_() properly for these tensor subclasses or apply a # more general fix to avoid the need for special set_() handling. For now, skip # testing these as they're expected to fail. if tensor_cls in [LoggingTensor, DiagTensorBelow]: return create_fn = partial(self._create_tensor, tensor_cls) class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.weight = nn.Parameter(create_fn()) def forward(self, x): return self.weight + x class MyParametrization(nn.Module): def forward(self, X): return -X m = MyModule() self.assertEqual(len(m.state_dict()), 1) register_parametrization(m, 'weight', MyParametrization()) self.assertIsInstance(m.weight, tensor_cls) output = m(self._create_tensor(torch.Tensor)) self.assertIsInstance(output, tensor_cls) remove_parametrizations(m, 'weight', leave_parametrized=leave_parametrized) # Lazy modules with custom tensors are not supported yet. @expectedFailure @parametrize_tensor_cls def test_lazy_module(self, tensor_cls): if tensor_cls is torch.Tensor: self.fail('dummy fail for base tensor until the test passes for subclasses') class MyLazyModule(LazyModuleMixin, nn.Module): def __init__(self) -> None: super().__init__() self.param = nn.UninitializedParameter() def initialize_parameters(self, input) -> None: # type: ignore[override] if self.has_uninitialized_params(): with torch.no_grad(): self.param.materialize(input.shape) nn.init.uniform_(self.param) def forward(self, x): return self.param + x m = MyLazyModule() self.assertTrue(m.has_uninitialized_params()) output = m(self._create_tensor(tensor_cls)) self.assertFalse(m.has_uninitialized_params()) self.assertIsInstance(m.param, tensor_cls) def test_non_rewrapping_torch_dispatch_subclass_as_parameter_throws_for_detach(self): # Define a subclass that does not rewrap for any function in its __torch_dispatch__ impl. class NonRewrappingTensor(torch.Tensor): @staticmethod def __new__( cls, t: torch.Tensor ): r = super()._make_wrapper_subclass( cls, t.shape, dtype=t.dtype, requires_grad=t.requires_grad, device=t.device) return r def __init__(self, t) -> None: self.tensor: torch.Tensor = t @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def unwrap(e) -> torch.Tensor: if isinstance(e, NonRewrappingTensor): t = e.tensor return t else: return e r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) # Return an unwrapped tensor no longer of original subclass type. return r with self.assertRaisesRegex(RuntimeError, r"requires that detach\(\) returns an instance of the same type"): param = nn.Parameter(NonRewrappingTensor(torch.randn(3))) def test_tensor_subclass_storage_data_accesses_throw(self): from torch.testing._internal.logging_tensor import LoggingTensor x = torch.ones(2) x_log = LoggingTensor(x) # Accessing storage on a tensor subclass is valid storage = x_log.untyped_storage() # This includes accessing metadata on the storage sz = storage.size() # But storage methods that access data will throw with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"): storage.data_ptr() with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"): storage.resize_(0) with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"): storage.copy_(storage) with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"): storage.fill_(0) with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"): storage._write_file("file") instantiate_parametrized_tests(TestSubclass) if __name__ == '__main__': run_tests()