xref: /aosp_15_r20/external/pytorch/c10/util/UniqueVoidPtr.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker #include <cstddef>
3*da0073e9SAndroid Build Coastguard Worker #include <memory>
4*da0073e9SAndroid Build Coastguard Worker #include <utility>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Export.h>
7*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker namespace c10 {
10*da0073e9SAndroid Build Coastguard Worker 
// Signature of a context deleter: a plain function pointer that receives the
// type-erased context pointer and returns nothing.
using DeleterFnPtr = void (*)(void* /*ctx*/);
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker namespace detail {
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker // Does not delete anything
16*da0073e9SAndroid Build Coastguard Worker C10_API void deleteNothing(void*);
17*da0073e9SAndroid Build Coastguard Worker 
18*da0073e9SAndroid Build Coastguard Worker // A detail::UniqueVoidPtr is an owning smart pointer like unique_ptr, but
19*da0073e9SAndroid Build Coastguard Worker // with three major differences:
20*da0073e9SAndroid Build Coastguard Worker //
21*da0073e9SAndroid Build Coastguard Worker //    1) It is specialized to void
22*da0073e9SAndroid Build Coastguard Worker //
23*da0073e9SAndroid Build Coastguard Worker //    2) It is specialized for a function pointer deleter
24*da0073e9SAndroid Build Coastguard Worker //       void(void* ctx); i.e., the deleter doesn't take a
25*da0073e9SAndroid Build Coastguard Worker //       reference to the data, just to a context pointer
26*da0073e9SAndroid Build Coastguard Worker //       (erased as void*).  In fact, internally, this pointer
27*da0073e9SAndroid Build Coastguard Worker //       is implemented as having an owning reference to
28*da0073e9SAndroid Build Coastguard Worker //       context, and a non-owning reference to data; this is why
29*da0073e9SAndroid Build Coastguard Worker //       you release_context(), not release() (the conventional
30*da0073e9SAndroid Build Coastguard Worker //       API for release() wouldn't give you enough information
31*da0073e9SAndroid Build Coastguard Worker //       to properly dispose of the object later.)
32*da0073e9SAndroid Build Coastguard Worker //
33*da0073e9SAndroid Build Coastguard Worker //    3) The deleter is guaranteed to be called when the unique
34*da0073e9SAndroid Build Coastguard Worker //       pointer is destructed and the context is non-null; this is different
35*da0073e9SAndroid Build Coastguard Worker //       from std::unique_ptr where the deleter is not called if the
36*da0073e9SAndroid Build Coastguard Worker //       data pointer is null.
37*da0073e9SAndroid Build Coastguard Worker //
38*da0073e9SAndroid Build Coastguard Worker // Some of the methods have slightly different types than std::unique_ptr
39*da0073e9SAndroid Build Coastguard Worker // to reflect this.
40*da0073e9SAndroid Build Coastguard Worker //
41*da0073e9SAndroid Build Coastguard Worker class UniqueVoidPtr {
42*da0073e9SAndroid Build Coastguard Worker  private:
43*da0073e9SAndroid Build Coastguard Worker   // Lifetime tied to ctx_
44*da0073e9SAndroid Build Coastguard Worker   void* data_;
45*da0073e9SAndroid Build Coastguard Worker   std::unique_ptr<void, DeleterFnPtr> ctx_;
46*da0073e9SAndroid Build Coastguard Worker 
47*da0073e9SAndroid Build Coastguard Worker  public:
UniqueVoidPtr()48*da0073e9SAndroid Build Coastguard Worker   UniqueVoidPtr() : data_(nullptr), ctx_(nullptr, &deleteNothing) {}
UniqueVoidPtr(void * data)49*da0073e9SAndroid Build Coastguard Worker   explicit UniqueVoidPtr(void* data)
50*da0073e9SAndroid Build Coastguard Worker       : data_(data), ctx_(nullptr, &deleteNothing) {}
UniqueVoidPtr(void * data,void * ctx,DeleterFnPtr ctx_deleter)51*da0073e9SAndroid Build Coastguard Worker   UniqueVoidPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter)
52*da0073e9SAndroid Build Coastguard Worker       : data_(data), ctx_(ctx, ctx_deleter ? ctx_deleter : &deleteNothing) {}
53*da0073e9SAndroid Build Coastguard Worker   void* operator->() const {
54*da0073e9SAndroid Build Coastguard Worker     return data_;
55*da0073e9SAndroid Build Coastguard Worker   }
clear()56*da0073e9SAndroid Build Coastguard Worker   void clear() {
57*da0073e9SAndroid Build Coastguard Worker     ctx_ = nullptr;
58*da0073e9SAndroid Build Coastguard Worker     data_ = nullptr;
59*da0073e9SAndroid Build Coastguard Worker   }
get()60*da0073e9SAndroid Build Coastguard Worker   void* get() const {
61*da0073e9SAndroid Build Coastguard Worker     return data_;
62*da0073e9SAndroid Build Coastguard Worker   }
get_context()63*da0073e9SAndroid Build Coastguard Worker   void* get_context() const {
64*da0073e9SAndroid Build Coastguard Worker     return ctx_.get();
65*da0073e9SAndroid Build Coastguard Worker   }
release_context()66*da0073e9SAndroid Build Coastguard Worker   void* release_context() {
67*da0073e9SAndroid Build Coastguard Worker     return ctx_.release();
68*da0073e9SAndroid Build Coastguard Worker   }
move_context()69*da0073e9SAndroid Build Coastguard Worker   std::unique_ptr<void, DeleterFnPtr>&& move_context() {
70*da0073e9SAndroid Build Coastguard Worker     return std::move(ctx_);
71*da0073e9SAndroid Build Coastguard Worker   }
compare_exchange_deleter(DeleterFnPtr expected_deleter,DeleterFnPtr new_deleter)72*da0073e9SAndroid Build Coastguard Worker   C10_NODISCARD bool compare_exchange_deleter(
73*da0073e9SAndroid Build Coastguard Worker       DeleterFnPtr expected_deleter,
74*da0073e9SAndroid Build Coastguard Worker       DeleterFnPtr new_deleter) {
75*da0073e9SAndroid Build Coastguard Worker     if (get_deleter() != expected_deleter)
76*da0073e9SAndroid Build Coastguard Worker       return false;
77*da0073e9SAndroid Build Coastguard Worker     ctx_ = std::unique_ptr<void, DeleterFnPtr>(ctx_.release(), new_deleter);
78*da0073e9SAndroid Build Coastguard Worker     return true;
79*da0073e9SAndroid Build Coastguard Worker   }
80*da0073e9SAndroid Build Coastguard Worker 
81*da0073e9SAndroid Build Coastguard Worker   template <typename T>
cast_context(DeleterFnPtr expected_deleter)82*da0073e9SAndroid Build Coastguard Worker   T* cast_context(DeleterFnPtr expected_deleter) const {
83*da0073e9SAndroid Build Coastguard Worker     if (get_deleter() != expected_deleter)
84*da0073e9SAndroid Build Coastguard Worker       return nullptr;
85*da0073e9SAndroid Build Coastguard Worker     return static_cast<T*>(get_context());
86*da0073e9SAndroid Build Coastguard Worker   }
87*da0073e9SAndroid Build Coastguard Worker   operator bool() const {
88*da0073e9SAndroid Build Coastguard Worker     return data_ || ctx_;
89*da0073e9SAndroid Build Coastguard Worker   }
get_deleter()90*da0073e9SAndroid Build Coastguard Worker   DeleterFnPtr get_deleter() const {
91*da0073e9SAndroid Build Coastguard Worker     return ctx_.get_deleter();
92*da0073e9SAndroid Build Coastguard Worker   }
93*da0073e9SAndroid Build Coastguard Worker };
94*da0073e9SAndroid Build Coastguard Worker 
95*da0073e9SAndroid Build Coastguard Worker // Note [How UniqueVoidPtr is implemented]
96*da0073e9SAndroid Build Coastguard Worker // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
97*da0073e9SAndroid Build Coastguard Worker // UniqueVoidPtr solves a common problem for allocators of tensor data, which
98*da0073e9SAndroid Build Coastguard Worker // is that the data pointer (e.g., float*) which you are interested in, is not
99*da0073e9SAndroid Build Coastguard Worker // the same as the context pointer (e.g., DLManagedTensor) which you need
100*da0073e9SAndroid Build Coastguard Worker // to actually deallocate the data.  Under a conventional deleter design, you
101*da0073e9SAndroid Build Coastguard Worker // have to store extra context in the deleter itself so that you can actually
102*da0073e9SAndroid Build Coastguard Worker // delete the right thing.  Implementing this with standard C++ is somewhat
103*da0073e9SAndroid Build Coastguard Worker // error-prone: if you use a std::unique_ptr to manage tensors, the deleter will
104*da0073e9SAndroid Build Coastguard Worker // not be called if the data pointer is nullptr, which can cause a leak if the
105*da0073e9SAndroid Build Coastguard Worker // context pointer is non-null (and the deleter is responsible for freeing both
106*da0073e9SAndroid Build Coastguard Worker // the data pointer and the context pointer).
107*da0073e9SAndroid Build Coastguard Worker //
// So, in our reimplementation of unique_ptr, we just store the context
// directly in the unique pointer, and attach the deleter to the context
// pointer itself.  In simple cases, the context pointer is just the data
// pointer itself.
112*da0073e9SAndroid Build Coastguard Worker 
113*da0073e9SAndroid Build Coastguard Worker inline bool operator==(const UniqueVoidPtr& sp, std::nullptr_t) noexcept {
114*da0073e9SAndroid Build Coastguard Worker   return !sp;
115*da0073e9SAndroid Build Coastguard Worker }
116*da0073e9SAndroid Build Coastguard Worker inline bool operator==(std::nullptr_t, const UniqueVoidPtr& sp) noexcept {
117*da0073e9SAndroid Build Coastguard Worker   return !sp;
118*da0073e9SAndroid Build Coastguard Worker }
119*da0073e9SAndroid Build Coastguard Worker inline bool operator!=(const UniqueVoidPtr& sp, std::nullptr_t) noexcept {
120*da0073e9SAndroid Build Coastguard Worker   return sp;
121*da0073e9SAndroid Build Coastguard Worker }
122*da0073e9SAndroid Build Coastguard Worker inline bool operator!=(std::nullptr_t, const UniqueVoidPtr& sp) noexcept {
123*da0073e9SAndroid Build Coastguard Worker   return sp;
124*da0073e9SAndroid Build Coastguard Worker }
125*da0073e9SAndroid Build Coastguard Worker 
126*da0073e9SAndroid Build Coastguard Worker } // namespace detail
127*da0073e9SAndroid Build Coastguard Worker } // namespace c10
128