xref: /aosp_15_r20/external/pytorch/c10/test/core/impl/cow_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/core/impl/COW.h>
#include <c10/core/impl/COWDeleter.h>

#include <c10/core/CPUAllocator.h>
#include <c10/core/StorageImpl.h>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstddef>
#include <cstring>
#include <memory>
#include <variant>

13 // NOLINTBEGIN(clang-analyzer-cplusplus*)
14 namespace c10::impl {
15 namespace {
16 
// Test helper that increments a caller-owned counter in its destructor, so a
// test can observe exactly when (and how many times) an owned object is
// destroyed.
class DeleteTracker {
 public:
  // `delete_count` is stored by reference and incremented on destruction; it
  // must outlive this tracker.
  explicit DeleteTracker(int& delete_count) : delete_count_(delete_count) {}

  // A copy or move would produce a second destructor call and thus count one
  // logical deletion twice, so both are forbidden.
  DeleteTracker(const DeleteTracker&) = delete;
  DeleteTracker& operator=(const DeleteTracker&) = delete;
  DeleteTracker(DeleteTracker&&) = delete;
  DeleteTracker& operator=(DeleteTracker&&) = delete;

  ~DeleteTracker() {
    ++delete_count_;
  }

 private:
  int& delete_count_;
};
27 
28 class ContextTest : public testing::Test {
29  protected:
delete_count() const30   auto delete_count() const -> int {
31     return delete_count_;
32   }
new_delete_tracker()33   auto new_delete_tracker() -> std::unique_ptr<void, DeleterFnPtr> {
34     return {new DeleteTracker(delete_count_), +[](void* ptr) {
35               delete static_cast<DeleteTracker*>(ptr);
36             }};
37   }
38 
39  private:
40   int delete_count_ = 0;
41 };
42 
TEST_F(ContextTest, Basic) {
  using Context = cow::COWDeleterContext;

  // Starts with one reference: after the increment below, the second
  // decrement is the last one.
  auto& ctx = *new Context(new_delete_tracker());
  ASSERT_THAT(delete_count(), testing::Eq(0));

  // Take a second reference.
  ctx.increment_refcount();

  {
    // Releasing one of two references is not the last release. The result is
    // kept in its own scope because it is expected to hold a shared lock.
    auto first = ctx.decrement_refcount();
    ASSERT_THAT(
        std::holds_alternative<Context::NotLastReference>(first),
        testing::IsTrue());
    ASSERT_THAT(delete_count(), testing::Eq(0));
  }

  {
    // Releasing the final reference moves the DeleteTracker into the result,
    // so nothing has been deleted yet.
    auto last = ctx.decrement_refcount();
    ASSERT_THAT(
        std::holds_alternative<Context::LastReference>(last),
        testing::IsTrue());
    ASSERT_THAT(delete_count(), testing::Eq(0));
  }

  // Leaving the scope destroyed the result, and with it the DeleteTracker.
  ASSERT_THAT(delete_count(), testing::Eq(1));
}
72 
TEST_F(ContextTest, cow_deleter) {
  // Calling cow_deleter on a context that holds a single reference should act
  // like the final decrement_refcount() in the Basic test: the wrapped
  // DeleteTracker is destroyed immediately.
  auto* const ctx = new cow::COWDeleterContext(new_delete_tracker());
  ASSERT_THAT(delete_count(), testing::Eq(0));

  cow::cow_deleter(ctx);
  ASSERT_THAT(delete_count(), testing::Eq(1));
}
81 
82 MATCHER(is_copy_on_write, "") {
83   const c10::StorageImpl& storage = std::ref(arg);
84   return cow::is_cow_data_ptr(storage.data_ptr());
85 }
86 
TEST(lazy_clone_storage_test, no_context) {
  using testing::Eq;
  using testing::Ne;
  using testing::Not;
  using testing::NotNull;

  // A freshly allocated storage starts out with a plain (non-COW) DataPtr.
  StorageImpl source(
      {}, /*size_bytes=*/7, GetDefaultCPUAllocator(), /*resizable=*/false);
  ASSERT_THAT(source, Not(is_copy_on_write()));
  ASSERT_TRUE(cow::has_simple_data_ptr(source));

  intrusive_ptr<StorageImpl> clone = cow::lazy_clone_storage(source);
  ASSERT_THAT(clone.get(), NotNull());

  // Lazy cloning rewrote the source in place so that it now holds a
  // copy-on-write context.
  ASSERT_THAT(source, is_copy_on_write());

  // The clone is a separate StorageImpl object ...
  ASSERT_THAT(clone.get(), Ne(&source));
  // ... that is copy-on-write as well ...
  ASSERT_THAT(*clone, is_copy_on_write());
  // ... and points at the very same bytes.
  ASSERT_THAT(clone->data(), Eq(source.data()));
}
108 
// Deleter context that owns a buffer allocated with `new std::byte[...]` and
// releases it with `delete[]` when destroyed. Used to build a DataPtr whose
// context is something other than a COWDeleterContext.
struct MyDeleterContext {
  // Takes ownership of `bytes`, which must come from `new std::byte[...]`.
  // Explicit so a bare void* never silently converts into an owning context.
  explicit MyDeleterContext(void* bytes) : bytes(bytes) {}

  // The context owns `bytes`; a copy would free the same buffer twice.
  MyDeleterContext(const MyDeleterContext&) = delete;
  MyDeleterContext& operator=(const MyDeleterContext&) = delete;

  ~MyDeleterContext() {
    delete[] static_cast<std::byte*>(bytes);
  }

  void* bytes;
};
118 
my_deleter(void * ctx)119 void my_deleter(void* ctx) {
120   delete static_cast<MyDeleterContext*>(ctx);
121 }
122 
TEST(lazy_clone_storage_test, different_context) {
  // Build a storage whose DataPtr context is a MyDeleterContext, i.e. neither
  // a simple data pointer nor a COW context.
  void* buffer = new std::byte[5];
  StorageImpl storage(
      {},
      /*size_bytes=*/5,
      at::DataPtr(
          /*data=*/buffer,
          /*ctx=*/new MyDeleterContext(buffer),
          /*ctx_deleter=*/my_deleter,
          /*device=*/Device(Device::Type::CPU)),
      /*allocator=*/nullptr,
      /*resizable=*/false);

  // We can't handle an arbitrary context, so no clone is produced.
  intrusive_ptr<StorageImpl> clone = cow::lazy_clone_storage(storage);
  ASSERT_THAT(clone, testing::IsNull());
}
139 
TEST(lazy_clone_storage_test, already_copy_on_write) {
  using testing::Eq;
  using testing::Ne;
  using testing::NotNull;

  // Owning handle to a 5-byte buffer; ownership is handed to the COW context.
  DeleterFnPtr free_bytes = +[](void* p) {
    delete[] static_cast<std::byte*>(p);
  };
  std::unique_ptr<void, DeleterFnPtr> buffer(new std::byte[5], free_bytes);
  void* const raw = buffer.get();

  // A storage that is copy-on-write from the start.
  StorageImpl source(
      {},
      /*size_bytes=*/5,
      at::DataPtr(
          /*data=*/raw,
          /*ctx=*/new cow::COWDeleterContext(std::move(buffer)),
          cow::cow_deleter,
          Device(Device::Type::CPU)),
      /*allocator=*/nullptr,
      /*resizable=*/false);
  ASSERT_THAT(source, is_copy_on_write());

  intrusive_ptr<StorageImpl> clone = cow::lazy_clone_storage(source);
  ASSERT_THAT(clone.get(), NotNull());

  // The clone is a separate StorageImpl object ...
  ASSERT_THAT(clone.get(), Ne(&source));
  // ... that is copy-on-write as well ...
  ASSERT_THAT(*clone, is_copy_on_write());
  // ... and shares the same underlying bytes.
  ASSERT_THAT(clone->data(), Eq(source.data()));
}
169 
TEST(materialize_test, not_copy_on_write_context) {
  StorageImpl plain(
      {}, /*size_bytes=*/6, GetCPUAllocator(), /*resizable=*/false);
  ASSERT_THAT(plain, testing::Not(is_copy_on_write()));

  const void* const before = plain.data();

  // Without a COW context there is nothing to materialize: mutable access must
  // return the very same buffer.
  ASSERT_THAT(plain.mutable_data(), testing::Eq(before));
}
180 
TEST(materialize_test, copy_on_write_single_reference) {
  // When a copy-on-write storage holds the only reference to its data,
  // materialization can simply drop the COW context instead of copying.
  DeleterFnPtr free_bytes = +[](void* p) {
    delete[] static_cast<std::byte*>(p);
  };
  std::unique_ptr<void, DeleterFnPtr> buffer(new std::byte[4], free_bytes);
  void* const raw = buffer.get();

  StorageImpl storage(
      {},
      /*size_bytes=*/4,
      at::DataPtr(
          /*data=*/raw,
          /*ctx=*/new cow::COWDeleterContext(std::move(buffer)),
          cow::cow_deleter,
          Device(Device::Type::CPU)),
      /*allocator=*/nullptr,
      /*resizable=*/false);
  ASSERT_THAT(storage, is_copy_on_write());
  ASSERT_THAT(storage.data(), testing::Eq(raw));

  const void* const before = storage.data();

  // Sole reference: materializing must not allocate, so the pointer stays put
  // ...
  ASSERT_THAT(storage.mutable_data(), testing::Eq(before));
  // ... yet the storage has stopped being copy-on-write.
  ASSERT_THAT(storage, testing::Not(is_copy_on_write()));
}
210 
// Returns true when the first `nbytes` bytes at `a` and `b` are identical.
// A zero-length comparison is always true and touches neither pointer, so
// null pointers are acceptable in that case.
bool buffers_are_equal(const void* a, const void* b, size_t nbytes) {
  // std::memcmp requires valid pointers even for a length of 0, so handle the
  // empty case before calling it.
  if (nbytes == 0) {
    return true;
  }
  return std::memcmp(a, b, nbytes) == 0;
}
222 
TEST(materialize_test, copy_on_write) {
  StorageImpl original_storage(
      {}, /*size_bytes=*/6, GetCPUAllocator(), /*resizable=*/false);
  // Zero every byte before writing the payload. The comparison at the end
  // reads all nbytes() bytes, but the memcpy sets only 4 of the 6, so the last
  // two would otherwise be read uninitialized (flagged by MSan).
  std::memset(original_storage.mutable_data(), 0, original_storage.nbytes());
  std::memcpy(original_storage.mutable_data(), "abcd", 4);
  void const* original_data = original_storage.data();

  auto new_storage = cow::lazy_clone_storage(original_storage);
  ASSERT_THAT(new_storage, testing::NotNull());

  // The clone's DataPtr context is a COWDeleterContext.
  auto context = new_storage->data_ptr().cast_context<cow::COWDeleterContext>(
      cow::cow_deleter);
  ASSERT_THAT(context, testing::NotNull());

  // Materialized storage has new copy of data.
  ASSERT_THAT(new_storage->mutable_data(), testing::Ne(original_data));

  // But the original storage still has the original copy.
  ASSERT_THAT(original_storage.data(), testing::Eq(original_data));

  // And both buffers hold the same bytes.
  ASSERT_EQ(new_storage->nbytes(), original_storage.nbytes());
  ASSERT_TRUE(buffers_are_equal(
      new_storage->data(), original_storage.data(), new_storage->nbytes()));
}
247 
248 } // namespace
249 } // namespace c10::impl
250 // NOLINTEND(clang-analyzer-cplusplus*)
251