xref: /aosp_15_r20/external/pytorch/c10/core/impl/COW.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/impl/COW.h>
2 
3 #include <c10/core/Allocator.h>
4 #include <c10/core/StorageImpl.h>
5 #include <c10/core/alignment.h>
6 #include <c10/core/impl/COWDeleter.h>
7 #include <c10/util/Exception.h>
8 #include <c10/util/ParallelGuard.h>
9 #include <c10/util/UniqueVoidPtr.h>
10 
11 #include <memory>
12 #include <optional>
13 
14 namespace c10::impl::cow {
15 
16 namespace {
17 
18 // Wraps a DataPtr with a copy-on-write DataPtr.
make_data_ptr(at::DataPtr const & data_ptr,cow::COWDeleterContext & ctx)19 at::DataPtr make_data_ptr(
20     at::DataPtr const& data_ptr,
21     cow::COWDeleterContext& ctx) {
22   return at::DataPtr(data_ptr.get(), &ctx, cow::cow_deleter, data_ptr.device());
23 }
24 
25 /// Copies a copy-on-write DataPtr.
copy_data_ptr(at::DataPtr const & data_ptr)26 at::DataPtr copy_data_ptr(at::DataPtr const& data_ptr) {
27   auto* ctx = data_ptr.cast_context<cow::COWDeleterContext>(cow::cow_deleter);
28   TORCH_INTERNAL_ASSERT(ctx != nullptr);
29   ctx->increment_refcount();
30   return make_data_ptr(data_ptr, *ctx);
31 }
32 
33 } // namespace
34 
has_simple_data_ptr(const c10::StorageImpl & storage)35 bool has_simple_data_ptr(const c10::StorageImpl& storage) {
36   const c10::DataPtr& data_ptr = storage.data_ptr();
37   const void* ctx = data_ptr.get_context();
38   const void* data = data_ptr.get();
39   const c10::Allocator* allocator = storage.allocator();
40   if (allocator != nullptr) {
41     return allocator->is_simple_data_ptr(data_ptr);
42   } else {
43     return ctx == data;
44   }
45 }
46 
is_cow_data_ptr(const c10::DataPtr & data_ptr)47 bool is_cow_data_ptr(const c10::DataPtr& data_ptr) {
48   return (void*)data_ptr.get_deleter() == (void*)&cow::cow_deleter;
49 }
50 
lazy_clone_storage(StorageImpl & storage)51 c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
52   const at::DataPtr& data_ptr = storage.data_ptr();
53 
54   // There are three possible circumstances:
55   //
56   // 1) The storage has a normal data pointer with no out of the ordinary
57   //    context. In this case we know that there are no blind aliases to the
58   //    storage impl: they all will be public aliases and the user is expected
59   //    to synchronize manually.
60   //
61   //    No locking is required in this case.
62   //
63   // 2) The storage already has a copy on write context. There
64   //    is a potential race condition with a blind alias (i.e. an
65   //    alias that the user is not required to synchronize
66   //    with). Because our input storage is bound to a live reference
67   //    to the data, we know that it isn't going away. A blind alias
68   //    could be copying from it right now, but we will grab the
69   //    context's mutex to protect us.
70   //
71   //    We do not need to lock in this case either, because we're just
72   //    wrapping a context that we know isn't going away.
73   //
74   // 3) The storage has a context that is not the copy on write
75   //    context. This is not supported, so we just return null.
76   //
77   //    No locking is required in this case.
78 
79   std::optional<DataPtr> new_data_ptr; // must be set below
80 
81   if (has_simple_data_ptr(storage)) {
82     // Case 1) We have a simple data pointer: wrap it.
83     std::unique_ptr<void, DeleterFnPtr> original_ctx =
84         storage._mutable_data_ptr_no_checks().move_context();
85 
86     // Save this for the result.
87     new_data_ptr = make_data_ptr(
88         data_ptr, *new cow::COWDeleterContext(std::move(original_ctx)));
89 
90     // Update this storage to the new copy on write context.
91     storage.set_data_ptr_noswap(copy_data_ptr(*new_data_ptr));
92   } else if (is_cow_data_ptr(data_ptr)) {
93     // Case 2): there is already a copy on write context. Just return a
94     // new storage impl.
95     new_data_ptr = copy_data_ptr(data_ptr);
96   } else {
97     // Case 3) There is a context and it's not copy-on-write. Nothing
98     // we can do here.
99     return nullptr;
100   }
101 
102   TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
103 
104   return make_storage_impl(
105       StorageImpl::use_byte_size_t(),
106       storage.sym_nbytes(),
107       *std::move(new_data_ptr),
108       storage.allocator(),
109       storage.resizable(),
110       storage.device_type());
111 }
112 
materialize_cow_storage(StorageImpl & storage)113 C10_API void materialize_cow_storage(StorageImpl& storage) {
114   TORCH_INTERNAL_ASSERT(
115       !c10::ParallelGuard::is_enabled(),
116       "Materializing a storage in the loop function of at::parallel_for is forbidden");
117   const at::DataPtr& data_ptr = storage.data_ptr();
118 
119   auto* ctx = data_ptr.cast_context<cow::COWDeleterContext>(cow::cow_deleter);
120   TORCH_INTERNAL_ASSERT(ctx != nullptr);
121 
122   auto result = ctx->decrement_refcount();
123 
124   // This must be set by each branch below.
125   std::optional<DataPtr> new_data_ptr;
126 
127   if (std::holds_alternative<cow::COWDeleterContext::LastReference>(result)) {
128     // This is the only reference to the data. If there were any racing writes,
129     // the context ensured they finished before giving us the result.
130     std::unique_ptr<void, DeleterFnPtr> data =
131         std::get<cow::COWDeleterContext::LastReference>(std::move(result));
132     TORCH_INTERNAL_ASSERT(data.get() == data_ptr.get());
133     new_data_ptr = DataPtr(
134         data.release(), data_ptr.get(), data.get_deleter(), data_ptr.device());
135   } else {
136     TORCH_INTERNAL_ASSERT(
137         std::holds_alternative<cow::COWDeleterContext::NotLastReference>(
138             result));
139     // We don't need to consume the result, it's just a shared lock ensuring
140     // that the data will remain while we copy it.
141     new_data_ptr = storage.allocator()->clone(data_ptr.get(), storage.nbytes());
142   }
143 
144   TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
145   DataPtr old_data_ptr =
146       storage.set_data_ptr_no_materialize_cow(*std::move(new_data_ptr));
147   // The refcount of the context was already decremented above. Release the
148   // reference to the context so the refcount doesn't get decremented again
149   old_data_ptr.release_context();
150 }
151 
152 } // namespace c10::impl::cow
153