xref: /aosp_15_r20/external/pytorch/c10/core/impl/COWDeleter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/impl/COWDeleter.h>
2 #include <c10/util/Exception.h>
3 #include <mutex>
4 
5 namespace c10::impl {
6 
cow_deleter(void * ctx)7 void cow::cow_deleter(void* ctx) {
8   static_cast<cow::COWDeleterContext*>(ctx)->decrement_refcount();
9 }
10 
COWDeleterContext(std::unique_ptr<void,DeleterFnPtr> data)11 cow::COWDeleterContext::COWDeleterContext(
12     std::unique_ptr<void, DeleterFnPtr> data)
13     : data_(std::move(data)) {
14   // We never wrap a COWDeleterContext.
15   TORCH_INTERNAL_ASSERT(data_.get_deleter() != cow::cow_deleter);
16 }
17 
increment_refcount()18 auto cow::COWDeleterContext::increment_refcount() -> void {
19   auto refcount = ++refcount_;
20   TORCH_INTERNAL_ASSERT(refcount > 1);
21 }
22 
decrement_refcount()23 auto cow::COWDeleterContext::decrement_refcount()
24     -> std::variant<NotLastReference, LastReference> {
25   auto refcount = --refcount_;
26   TORCH_INTERNAL_ASSERT(refcount >= 0, refcount);
27   if (refcount == 0) {
28     std::unique_lock lock(mutex_);
29     auto result = std::move(data_);
30     lock.unlock();
31     delete this;
32     return {std::move(result)};
33   }
34 
35   return std::shared_lock(mutex_);
36 }
37 
~COWDeleterContext()38 cow::COWDeleterContext::~COWDeleterContext() {
39   TORCH_INTERNAL_ASSERT(refcount_ == 0);
40 }
41 
42 } // namespace c10::impl
43