#include <c10/core/impl/COW.h>

#include <c10/core/Allocator.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/alignment.h>
#include <c10/core/impl/COWDeleter.h>
#include <c10/util/Exception.h>
#include <c10/util/ParallelGuard.h>
#include <c10/util/UniqueVoidPtr.h>

#include <memory>
#include <optional>
#include <variant>
13
14 namespace c10::impl::cow {
15
16 namespace {
17
18 // Wraps a DataPtr with a copy-on-write DataPtr.
make_data_ptr(at::DataPtr const & data_ptr,cow::COWDeleterContext & ctx)19 at::DataPtr make_data_ptr(
20 at::DataPtr const& data_ptr,
21 cow::COWDeleterContext& ctx) {
22 return at::DataPtr(data_ptr.get(), &ctx, cow::cow_deleter, data_ptr.device());
23 }
24
25 /// Copies a copy-on-write DataPtr.
copy_data_ptr(at::DataPtr const & data_ptr)26 at::DataPtr copy_data_ptr(at::DataPtr const& data_ptr) {
27 auto* ctx = data_ptr.cast_context<cow::COWDeleterContext>(cow::cow_deleter);
28 TORCH_INTERNAL_ASSERT(ctx != nullptr);
29 ctx->increment_refcount();
30 return make_data_ptr(data_ptr, *ctx);
31 }
32
33 } // namespace
34
has_simple_data_ptr(const c10::StorageImpl & storage)35 bool has_simple_data_ptr(const c10::StorageImpl& storage) {
36 const c10::DataPtr& data_ptr = storage.data_ptr();
37 const void* ctx = data_ptr.get_context();
38 const void* data = data_ptr.get();
39 const c10::Allocator* allocator = storage.allocator();
40 if (allocator != nullptr) {
41 return allocator->is_simple_data_ptr(data_ptr);
42 } else {
43 return ctx == data;
44 }
45 }
46
is_cow_data_ptr(const c10::DataPtr & data_ptr)47 bool is_cow_data_ptr(const c10::DataPtr& data_ptr) {
48 return (void*)data_ptr.get_deleter() == (void*)&cow::cow_deleter;
49 }
50
lazy_clone_storage(StorageImpl & storage)51 c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
52 const at::DataPtr& data_ptr = storage.data_ptr();
53
54 // There are three possible circumstances:
55 //
56 // 1) The storage has a normal data pointer with no out of the ordinary
57 // context. In this case we know that there are no blind aliases to the
58 // storage impl: they all will be public aliases and the user is expected
59 // to synchronize manually.
60 //
61 // No locking is required in this case.
62 //
63 // 2) The storage already has a copy on write context. There
64 // is a potential race condition with a blind alias (i.e. an
65 // alias that the user is not required to synchronize
66 // with). Because our input storage is bound to a live reference
67 // to the data, we know that it isn't going away. A blind alias
68 // could be copying from it right now, but we will grab the
69 // context's mutex to protect us.
70 //
71 // We do not need to lock in this case either, because we're just
72 // wrapping a context that we know isn't going away.
73 //
74 // 3) The storage has a context that is not the copy on write
75 // context. This is not supported, so we just return null.
76 //
77 // No locking is required in this case.
78
79 std::optional<DataPtr> new_data_ptr; // must be set below
80
81 if (has_simple_data_ptr(storage)) {
82 // Case 1) We have a simple data pointer: wrap it.
83 std::unique_ptr<void, DeleterFnPtr> original_ctx =
84 storage._mutable_data_ptr_no_checks().move_context();
85
86 // Save this for the result.
87 new_data_ptr = make_data_ptr(
88 data_ptr, *new cow::COWDeleterContext(std::move(original_ctx)));
89
90 // Update this storage to the new copy on write context.
91 storage.set_data_ptr_noswap(copy_data_ptr(*new_data_ptr));
92 } else if (is_cow_data_ptr(data_ptr)) {
93 // Case 2): there is already a copy on write context. Just return a
94 // new storage impl.
95 new_data_ptr = copy_data_ptr(data_ptr);
96 } else {
97 // Case 3) There is a context and it's not copy-on-write. Nothing
98 // we can do here.
99 return nullptr;
100 }
101
102 TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
103
104 return make_storage_impl(
105 StorageImpl::use_byte_size_t(),
106 storage.sym_nbytes(),
107 *std::move(new_data_ptr),
108 storage.allocator(),
109 storage.resizable(),
110 storage.device_type());
111 }
112
materialize_cow_storage(StorageImpl & storage)113 C10_API void materialize_cow_storage(StorageImpl& storage) {
114 TORCH_INTERNAL_ASSERT(
115 !c10::ParallelGuard::is_enabled(),
116 "Materializing a storage in the loop function of at::parallel_for is forbidden");
117 const at::DataPtr& data_ptr = storage.data_ptr();
118
119 auto* ctx = data_ptr.cast_context<cow::COWDeleterContext>(cow::cow_deleter);
120 TORCH_INTERNAL_ASSERT(ctx != nullptr);
121
122 auto result = ctx->decrement_refcount();
123
124 // This must be set by each branch below.
125 std::optional<DataPtr> new_data_ptr;
126
127 if (std::holds_alternative<cow::COWDeleterContext::LastReference>(result)) {
128 // This is the only reference to the data. If there were any racing writes,
129 // the context ensured they finished before giving us the result.
130 std::unique_ptr<void, DeleterFnPtr> data =
131 std::get<cow::COWDeleterContext::LastReference>(std::move(result));
132 TORCH_INTERNAL_ASSERT(data.get() == data_ptr.get());
133 new_data_ptr = DataPtr(
134 data.release(), data_ptr.get(), data.get_deleter(), data_ptr.device());
135 } else {
136 TORCH_INTERNAL_ASSERT(
137 std::holds_alternative<cow::COWDeleterContext::NotLastReference>(
138 result));
139 // We don't need to consume the result, it's just a shared lock ensuring
140 // that the data will remain while we copy it.
141 new_data_ptr = storage.allocator()->clone(data_ptr.get(), storage.nbytes());
142 }
143
144 TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
145 DataPtr old_data_ptr =
146 storage.set_data_ptr_no_materialize_cow(*std::move(new_data_ptr));
147 // The refcount of the context was already decremented above. Release the
148 // reference to the context so the refcount doesn't get decremented again
149 old_data_ptr.release_context();
150 }
151
152 } // namespace c10::impl::cow
153