#include <c10/core/impl/COW.h>
#include <c10/core/impl/COWDeleter.h>

#include <c10/core/CPUAllocator.h>
#include <c10/core/StorageImpl.h>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstddef>
#include <cstring>
#include <functional>
#include <memory>
#include <variant>
12
13 // NOLINTBEGIN(clang-analyzer-cplusplus*)
14 namespace c10::impl {
15 namespace {
16
// Test helper that makes destruction observable: every DeleteTracker that is
// destroyed adds one to the integer it was constructed with.
class DeleteTracker {
 public:
  // The counter is held by reference, so it has to outlive this tracker.
  explicit DeleteTracker(int& delete_count) : counter_(delete_count) {}

  ~DeleteTracker() {
    counter_ += 1;
  }

 private:
  // Caller-owned tally of destroyed trackers.
  int& counter_;
};
27
28 class ContextTest : public testing::Test {
29 protected:
delete_count() const30 auto delete_count() const -> int {
31 return delete_count_;
32 }
new_delete_tracker()33 auto new_delete_tracker() -> std::unique_ptr<void, DeleterFnPtr> {
34 return {new DeleteTracker(delete_count_), +[](void* ptr) {
35 delete static_cast<DeleteTracker*>(ptr);
36 }};
37 }
38
39 private:
40 int delete_count_ = 0;
41 };
42
// Takes a COWDeleterContext through increment / decrement / decrement and
// checks which variant alternative each decrement reports, and when the
// wrapped DeleteTracker is actually destroyed.
TEST_F(ContextTest, Basic) {
  auto* context = new cow::COWDeleterContext(new_delete_tracker());
  ASSERT_EQ(delete_count(), 0);

  context->increment_refcount();

  {
    // Scoped on purpose: this decrement is expected to hand back a shared
    // lock, which must be released before the next decrement.
    auto not_last = context->decrement_refcount();
    ASSERT_TRUE(
        std::holds_alternative<cow::COWDeleterContext::NotLastReference>(
            not_last));
    ASSERT_EQ(delete_count(), 0);
  }

  {
    auto last = context->decrement_refcount();
    ASSERT_TRUE(
        std::holds_alternative<cow::COWDeleterContext::LastReference>(last));
    // The DeleteTracker is still alive here because `last` holds it.
    ASSERT_EQ(delete_count(), 0);
  }

  // Destroying `last` destroyed the DeleteTracker along with it.
  ASSERT_EQ(delete_count(), 1);
}
72
// Checks that cow::cow_deleter on a freshly created context destroys the
// wrapped DeleteTracker.
TEST_F(ContextTest, cow_deleter) {
  // Effectively the same as the final decrement_refcount() in the Basic test,
  // but driven through the deleter entry point.
  auto* context = new cow::COWDeleterContext(new_delete_tracker());
  ASSERT_EQ(delete_count(), 0);

  cow::cow_deleter(context);
  ASSERT_EQ(delete_count(), 1);
}
81
82 MATCHER(is_copy_on_write, "") {
83 const c10::StorageImpl& storage = std::ref(arg);
84 return cow::is_cow_data_ptr(storage.data_ptr());
85 }
86
// Lazily cloning a plain allocator-backed storage: both the source and the
// clone must end up copy-on-write and point at the same bytes.
TEST(lazy_clone_storage_test, no_context) {
  StorageImpl source(
      {}, /*size_bytes=*/7, GetDefaultCPUAllocator(), /*resizable=*/false);
  ASSERT_THAT(source, testing::Not(is_copy_on_write()));
  ASSERT_TRUE(cow::has_simple_data_ptr(source));

  intrusive_ptr<StorageImpl> clone = cow::lazy_clone_storage(source);
  ASSERT_NE(clone.get(), nullptr);

  // Cloning rewrote the source in place so that it now carries a
  // copy-on-write context.
  ASSERT_THAT(source, is_copy_on_write());

  // The clone is a distinct StorageImpl object...
  ASSERT_NE(clone.get(), &source);
  // ...that is copy-on-write as well...
  ASSERT_THAT(*clone, is_copy_on_write());
  // ...and still shares the underlying data with the source.
  ASSERT_EQ(clone->data(), source.data());
}
108
// A deleter context that is deliberately NOT a copy-on-write context. It owns
// the byte array passed to its constructor and releases it with delete[] on
// destruction.
struct MyDeleterContext {
  // Takes ownership of `bytes`; the destructor frees it as a std::byte array.
  // Explicit so that a bare void* is never silently turned into an owner.
  explicit MyDeleterContext(void* bytes) : bytes(bytes) {}

  // Exactly one context may own the buffer: a copy would free it twice.
  // Declaring these also suppresses the implicit move operations.
  MyDeleterContext(const MyDeleterContext&) = delete;
  MyDeleterContext& operator=(const MyDeleterContext&) = delete;

  ~MyDeleterContext() {
    delete[] static_cast<std::byte*>(bytes);
  }

  // Owned buffer, freed in the destructor.
  void* bytes;
};
118
my_deleter(void * ctx)119 void my_deleter(void* ctx) {
120 delete static_cast<MyDeleterContext*>(ctx);
121 }
122
// A storage whose DataPtr carries a non-COW deleter context cannot be lazily
// cloned; lazy_clone_storage must return null for it.
TEST(lazy_clone_storage_test, different_context) {
  void* buffer = new std::byte[5];
  at::DataPtr foreign_data(
      /*data=*/buffer,
      /*ctx=*/new MyDeleterContext(buffer),
      /*ctx_deleter=*/my_deleter,
      /*device=*/Device(Device::Type::CPU));
  StorageImpl storage(
      {},
      /*size_bytes=*/5,
      std::move(foreign_data),
      /*allocator=*/nullptr,
      /*resizable=*/false);

  // An arbitrary context is not supported, so no clone is produced.
  auto clone = cow::lazy_clone_storage(storage);
  ASSERT_EQ(clone.get(), nullptr);
}
139
// Lazily cloning a storage that already carries a COWDeleterContext yields a
// second copy-on-write storage over the same bytes.
TEST(lazy_clone_storage_test, already_copy_on_write) {
  DeleterFnPtr free_bytes = +[](void* bytes) {
    delete[] static_cast<std::byte*>(bytes);
  };
  std::unique_ptr<void, DeleterFnPtr> owned_bytes(
      new std::byte[5], free_bytes);
  // Remember the raw address before ownership moves into the context.
  void* raw_bytes = owned_bytes.get();
  StorageImpl source(
      {},
      /*size_bytes=*/5,
      at::DataPtr(
          /*data=*/raw_bytes,
          /*ctx=*/new cow::COWDeleterContext(std::move(owned_bytes)),
          cow::cow_deleter,
          Device(Device::Type::CPU)),
      /*allocator=*/nullptr,
      /*resizable=*/false);

  ASSERT_THAT(source, is_copy_on_write());

  intrusive_ptr<StorageImpl> clone = cow::lazy_clone_storage(source);
  ASSERT_NE(clone.get(), nullptr);

  // The clone is a distinct storage object...
  ASSERT_NE(clone.get(), &source);
  // ...that is copy-on-write too...
  ASSERT_THAT(*clone, is_copy_on_write());
  // ...and shares the source's data.
  ASSERT_EQ(clone->data(), source.data());
}
169
// Asking a non-COW storage for mutable data must hand back the very same
// pointer that data() reported.
TEST(materialize_test, not_copy_on_write_context) {
  StorageImpl plain_storage(
      {}, /*size_bytes=*/6, GetCPUAllocator(), /*resizable=*/false);
  ASSERT_THAT(plain_storage, testing::Not(is_copy_on_write()));

  const void* data_before = plain_storage.data();

  // There is nothing to materialize, so the address is unchanged.
  ASSERT_EQ(plain_storage.mutable_data(), data_before);
}
180
// With a single reference to the copy-on-write context, materializing keeps
// the existing allocation and simply stops being copy-on-write.
TEST(materialize_test, copy_on_write_single_reference) {
  DeleterFnPtr free_bytes = +[](void* bytes) {
    delete[] static_cast<std::byte*>(bytes);
  };
  std::unique_ptr<void, DeleterFnPtr> owned_bytes(
      new std::byte[4], free_bytes);
  // Remember the raw address before ownership moves into the context.
  void* raw_bytes = owned_bytes.get();
  StorageImpl storage(
      {},
      /*size_bytes=*/4,
      at::DataPtr(
          /*data=*/raw_bytes,
          /*ctx=*/new cow::COWDeleterContext(std::move(owned_bytes)),
          cow::cow_deleter,
          Device(Device::Type::CPU)),
      /*allocator=*/nullptr,
      /*resizable=*/false);

  ASSERT_THAT(storage, is_copy_on_write());

  ASSERT_EQ(storage.data(), raw_bytes);

  const void* data_before = storage.data();

  // Materializing the sole reference must not allocate a new buffer.
  ASSERT_EQ(storage.mutable_data(), data_before);
  // Afterwards the storage is no longer copy-on-write.
  ASSERT_THAT(storage, testing::Not(is_copy_on_write()));
}
210
// Returns true when the first `nbytes` bytes at `a` and `b` are identical.
// A length of zero always compares equal and never touches either pointer.
bool buffers_are_equal(const void* a, const void* b, size_t nbytes) {
  // The zero-length guard keeps memcmp from ever being called with pointers
  // that may be null or otherwise not dereferenceable.
  return nbytes == 0 || std::memcmp(a, b, nbytes) == 0;
}
222
// Materializing a lazily cloned storage must give the clone its own buffer
// with the same contents while leaving the original storage untouched.
TEST(materialize_test, copy_on_write) {
  // The payload covers the storage completely (five characters plus the
  // terminating NUL), so the byte-wise comparisons below never read
  // uninitialized memory.
  constexpr char kPayload[] = "abcde";
  static_assert(
      sizeof(kPayload) == 6, "payload must fill the whole 6-byte storage");

  StorageImpl original_storage(
      {}, /*size_bytes=*/6, GetCPUAllocator(), /*resizable=*/false);
  std::memcpy(original_storage.mutable_data(), kPayload, sizeof(kPayload));
  void const* original_data = original_storage.data();

  auto new_storage = cow::lazy_clone_storage(original_storage);
  ASSERT_THAT(new_storage.get(), testing::NotNull());

  // The clone's data pointer must carry a COWDeleterContext.
  auto context = new_storage->data_ptr().cast_context<cow::COWDeleterContext>(
      cow::cow_deleter);
  ASSERT_THAT(context, testing::NotNull());

  // Materialized storage has new copy of data.
  ASSERT_THAT(new_storage->mutable_data(), testing::Ne(original_data));

  // But the original storage still has the original copy.
  ASSERT_THAT(original_storage.data(), testing::Eq(original_data));

  // And their data is the same
  ASSERT_TRUE(new_storage->nbytes() == original_storage.nbytes());
  ASSERT_TRUE(buffers_are_equal(
      new_storage->data(), original_storage.data(), new_storage->nbytes()));

  // The copy holds exactly the bytes that were written before cloning.
  ASSERT_TRUE(new_storage->nbytes() == sizeof(kPayload));
  ASSERT_TRUE(
      buffers_are_equal(new_storage->data(), kPayload, sizeof(kPayload)));
}
247
248 } // namespace
249 } // namespace c10::impl
250 // NOLINTEND(clang-analyzer-cplusplus*)
251