1# Owner(s): ["oncall: pt2"] 2 3import tempfile 4import unittest 5 6import torch 7from torch._prims.debug_prims import load_tensor_reader 8from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode 9from torch.multiprocessing.reductions import StorageWeakRef 10from torch.testing._internal.common_device_type import instantiate_device_type_tests 11from torch.testing._internal.common_utils import ( 12 IS_WINDOWS, 13 run_tests, 14 skipIfRocm, 15 TestCase, 16) 17from torch.utils._content_store import ( 18 ContentStoreReader, 19 ContentStoreWriter, 20 hash_storage, 21) 22 23 24@unittest.skipIf(IS_WINDOWS, "Test case not supported on Windows") 25class TestContentStore(TestCase): 26 def test_basic(self, device): 27 # setup test data 28 x = torch.randn(4, device=device) 29 y = torch.randn(6, device=device) 30 z = x.view(2, 2) 31 # start writing 32 with tempfile.TemporaryDirectory() as loc: 33 writer = ContentStoreWriter(loc) 34 writer.write_tensor("x", x) 35 writer.write_tensor("y", y) 36 writer.write_tensor("z", z) 37 # do some mutation that is VC UNTRACKED 38 x.data.add_(1) 39 writer.write_tensor("x2", x) 40 writer.write_tensor("y2", y) 41 writer.write_tensor("z2", z) 42 del writer 43 44 reader = ContentStoreReader(loc) 45 n_x = reader.read_tensor("x") 46 n_y = reader.read_tensor("y") 47 n_z = reader.read_tensor("z") 48 self.assertEqual(n_x + 1, x) 49 self.assertEqual(n_y, y) 50 self.assertEqual(n_z + 1, z) 51 self.assertEqual( 52 StorageWeakRef(n_x.untyped_storage()), 53 StorageWeakRef(n_z.untyped_storage()), 54 ) 55 n_x2 = reader.read_tensor("x2") 56 n_y2 = reader.read_tensor("y2") 57 n_z2 = reader.read_tensor("z2") 58 self.assertEqual(n_x2, x) 59 self.assertEqual(n_y2, y) 60 self.assertEqual(n_z2, z) 61 self.assertEqual( 62 StorageWeakRef(n_y2.untyped_storage()), 63 StorageWeakRef(n_y.untyped_storage()), 64 ) 65 66 def test_scalar(self, device): 67 # Should not raise an error 68 hash_storage(torch.tensor(2, device=device).untyped_storage()) 69 70 @torch._dynamo.config.patch(cache_size_limit=1) 71 def test_repeated_hash(self, device): 72 # Test that repeated hashing doesn't trigger a recompile in dynamo 73 # If it does, we will execute prims.xor_sum in eager which fails 74 for _ in range(4): 75 hash_storage(torch.tensor(2, device=device).untyped_storage()) 76 77 @skipIfRocm 78 def test_load_tensor(self, device): 79 with tempfile.TemporaryDirectory() as loc: 80 writer = ContentStoreWriter(loc) 81 x = torch.randn(4, device=device) 82 83 def same_meta_as_x(t): 84 self.assertEqual(t.size(), x.size()) 85 self.assertEqual(t.stride(), x.stride()) 86 self.assertEqual(t.dtype, x.dtype) 87 self.assertEqual(t.device, x.device) 88 89 writer.write_tensor("x", x) 90 91 with load_tensor_reader(loc): 92 x2 = torch.ops.debugprims.load_tensor.default( 93 "x", (4,), (1,), dtype=torch.float32, device=device 94 ) 95 self.assertEqual(x, x2) 96 x3 = torch.ops.debugprims.load_tensor.default( 97 "x", (4,), (1,), dtype=torch.float32, device=device 98 ) 99 self.assertEqual(x, x3) 100 # Must not alias! 
                self.assertNotEqual(
                    StorageWeakRef(x.untyped_storage()),
                    StorageWeakRef(x2.untyped_storage()),
                )
                self.assertNotEqual(
                    StorageWeakRef(x2.untyped_storage()),
                    StorageWeakRef(x3.untyped_storage()),
                )

                # Check fake tensor mode works too
                with FakeTensorMode():
                    x4 = torch.ops.debugprims.load_tensor.default(
                        "x", (4,), (1,), dtype=torch.float32, device=device
                    )
                    self.assertIsInstance(x4, FakeTensor)
                    same_meta_as_x(x4)

                # Check fp64 works: the tensor was saved as fp32 but can
                # be loaded upcast to fp64
                x5 = torch.ops.debugprims.load_tensor.default(
                    "x", (4,), (1,), dtype=torch.float64, device=device
                )
                self.assertEqual(x5.float(), x)
                self.assertEqual(x5.dtype, torch.float64)

            # Outside the reader context, load_tensor falls back to freshly
            # generated data, so only the metadata can be checked
            x6 = torch.ops.debugprims.load_tensor.default(
                "x", (4,), (1,), dtype=torch.float32, device=device
            )
            same_meta_as_x(x6)


instantiate_device_type_tests(TestContentStore, globals())


if __name__ == "__main__":
    run_tests()