xref: /aosp_15_r20/external/pytorch/test/test_content_store.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: pt2"]
2
3import tempfile
4import unittest
5
6import torch
7from torch._prims.debug_prims import load_tensor_reader
8from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
9from torch.multiprocessing.reductions import StorageWeakRef
10from torch.testing._internal.common_device_type import instantiate_device_type_tests
11from torch.testing._internal.common_utils import (
12    IS_WINDOWS,
13    run_tests,
14    skipIfRocm,
15    TestCase,
16)
17from torch.utils._content_store import (
18    ContentStoreReader,
19    ContentStoreWriter,
20    hash_storage,
21)
22
23
@unittest.skipIf(IS_WINDOWS, "Test case not supported on Windows")
class TestContentStore(TestCase):
    """Device-generic tests for the tensor content store.

    Covers writing and reading tensors with ``ContentStoreWriter`` /
    ``ContentStoreReader``, hashing storages with ``hash_storage``, and
    loading stored tensors through ``torch.ops.debugprims.load_tensor``.
    Every test receives the target ``device`` as an argument.
    """

    def test_basic(self, device):
        """Round-trip tensors through the store before and after a mutation."""
        base = torch.randn(4, device=device)
        independent = torch.randn(6, device=device)
        reshaped = base.view(2, 2)

        with tempfile.TemporaryDirectory() as store_dir:
            store_writer = ContentStoreWriter(store_dir)

            # First snapshot: contents as originally generated.
            store_writer.write_tensor("x", base)
            store_writer.write_tensor("y", independent)
            store_writer.write_tensor("z", reshaped)

            # Mutate through ``.data`` -- the original author flags this as a
            # mutation that the version counter does not track.
            base.data.add_(1)

            # Second snapshot: same tensors after the in-place update.
            store_writer.write_tensor("x2", base)
            store_writer.write_tensor("y2", independent)
            store_writer.write_tensor("z2", reshaped)
            del store_writer

            store_reader = ContentStoreReader(store_dir)

            # The first snapshot holds the pre-mutation values, so adding one
            # must reproduce the current contents of ``base`` and its view.
            old_x, old_y, old_z = (
                store_reader.read_tensor(key) for key in ("x", "y", "z")
            )
            self.assertEqual(old_x + 1, base)
            self.assertEqual(old_y, independent)
            self.assertEqual(old_z + 1, reshaped)
            # "x" and "z" were written from aliasing tensors and must come
            # back sharing one storage.
            self.assertEqual(
                StorageWeakRef(old_x.untyped_storage()),
                StorageWeakRef(old_z.untyped_storage()),
            )

            # The second snapshot matches the live tensors directly.
            new_x, new_y, new_z = (
                store_reader.read_tensor(key) for key in ("x2", "y2", "z2")
            )
            self.assertEqual(new_x, base)
            self.assertEqual(new_y, independent)
            self.assertEqual(new_z, reshaped)
            # "y" never changed between snapshots, so both reads are expected
            # to resolve to the same storage.
            self.assertEqual(
                StorageWeakRef(new_y.untyped_storage()),
                StorageWeakRef(old_y.untyped_storage()),
            )

    def test_scalar(self, device):
        """Hashing the storage of a zero-dim tensor must not raise."""
        scalar_storage = torch.tensor(2, device=device).untyped_storage()
        hash_storage(scalar_storage)

    @torch._dynamo.config.patch(cache_size_limit=1)
    def test_repeated_hash(self, device):
        """Hash several fresh storages under a dynamo cache size limit of 1.

        Repeated hashing must not trigger a recompile in dynamo; if it did,
        prims.xor_sum would be executed in eager, which fails.
        """
        for _round in range(4):
            hash_storage(torch.tensor(2, device=device).untyped_storage())

    @skipIfRocm
    def test_load_tensor(self, device):
        """Load a stored tensor via ``debugprims.load_tensor`` in several modes."""
        size, stride = (4,), (1,)

        with tempfile.TemporaryDirectory() as store_dir:
            store_writer = ContentStoreWriter(store_dir)
            reference = torch.randn(4, device=device)

            def assert_same_metadata(candidate):
                # Compare size, stride, dtype and device against the reference.
                self.assertEqual(candidate.size(), reference.size())
                self.assertEqual(candidate.stride(), reference.stride())
                self.assertEqual(candidate.dtype, reference.dtype)
                self.assertEqual(candidate.device, reference.device)

            def load(dtype):
                # Single call site for the op; only the dtype varies below.
                return torch.ops.debugprims.load_tensor.default(
                    "x", size, stride, dtype=dtype, device=device
                )

            store_writer.write_tensor("x", reference)

            with load_tensor_reader(store_dir):
                first_load = load(torch.float32)
                self.assertEqual(reference, first_load)
                second_load = load(torch.float32)
                self.assertEqual(reference, second_load)

                # Must not alias: neither the source tensor nor an earlier
                # load may share storage with a loaded tensor.
                self.assertNotEqual(
                    StorageWeakRef(reference.untyped_storage()),
                    StorageWeakRef(first_load.untyped_storage()),
                )
                self.assertNotEqual(
                    StorageWeakRef(first_load.untyped_storage()),
                    StorageWeakRef(second_load.untyped_storage()),
                )

                # Fake tensor mode: the result is a FakeTensor carrying the
                # same metadata as the reference.
                with FakeTensorMode():
                    fake_load = load(torch.float32)
                    self.assertIsInstance(fake_load, FakeTensor)
                    assert_same_metadata(fake_load)

                # Requesting fp64 yields a float64 tensor with the same values.
                wide_load = load(torch.float64)
                self.assertEqual(wide_load.float(), reference)
                self.assertEqual(wide_load.dtype, torch.float64)

        # Outside the reader context (and after the directory is removed) the
        # op is still called; only the metadata of its result is checked.
        detached_load = load(torch.float32)
        assert_same_metadata(detached_load)
129
130
# Hand the device-generic TestContentStore class and this module's globals to
# the device-type test instantiation helper.
# NOTE(review): presumably this generates per-device variants of the class and
# registers them in the module namespace so the runner can discover them --
# confirm against torch.testing._internal.common_device_type.
instantiate_device_type_tests(TestContentStore, globals())


# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
136