xref: /aosp_15_r20/external/pytorch/c10/util/intrusive_ptr.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/Exception.h>
4 #include <c10/util/MaybeOwned.h>
5 #include <atomic>
6 #include <climits>
7 #include <memory>
8 #include <type_traits>
9 
10 namespace pybind11 {
11 template <typename, typename...>
12 class class_;
13 }
14 
15 namespace c10 {
class intrusive_ptr_target;
namespace raw {
// Forward declarations of the raw-pointer incref helpers defined later in
// this header; they are friended by intrusive_ptr_target so they can touch
// the counts directly.
namespace weak_intrusive_ptr {
inline void incref(intrusive_ptr_target* self);
}
namespace intrusive_ptr {
inline void incref(intrusive_ptr_target* self);
}

// constructor tag used by intrusive_ptr constructors
struct DontIncreaseRefcount {};
} // namespace raw

namespace detail {
// Sentinel refcount used by intrusive_ptr::unsafe_adapt_non_heap_allocated
// and the debug assertions in ~intrusive_ptr_target: large enough that
// ordinary decref traffic can never plausibly reach it, while leaving ample
// headroom below UINT32_MAX.
constexpr uint32_t kImpracticallyHugeReferenceCount = 0x0FFFFFFF;
} // namespace detail
32 
33 /**
34  * intrusive_ptr<T> is an alternative to shared_ptr<T> that has better
35  * performance because it does the refcounting intrusively
36  * (i.e. in a member of the object itself).
37  * Your class T needs to inherit from intrusive_ptr_target to allow it to be
38  * used in an intrusive_ptr<T>. Your class's constructor should not allow
39  *`this` to escape to other threads or create an intrusive_ptr from `this`.
40  */
41 
42 // Note [Stack allocated intrusive_ptr_target safety]
43 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
44 // A well known problem with std::enable_shared_from_this is that it
45 // allows you to create a std::shared_ptr from a stack allocated object,
46 // which is totally bogus because the object will die once you return
47 // from the stack.  In intrusive_ptr, we can detect that this has occurred,
48 // because we set the refcount/weakcount of objects which inherit from
49 // intrusive_ptr_target to zero, *unless* we can prove that the object
50 // was dynamically allocated (e.g., via make_intrusive).
51 //
52 // Thus, whenever you transmute a T* into a intrusive_ptr<T>, we check
53 // and make sure that the refcount isn't zero (or, a more subtle
54 // test for weak_intrusive_ptr<T>, for which the refcount may validly
55 // be zero, but the weak refcount better not be zero), because that
56 // tells us if the object was allocated by us.  If it wasn't, no
57 // intrusive_ptr for you!
58 
// NOLINTNEXTLINE(cppcoreguidelines-virtual-class-destructor)
class C10_API intrusive_ptr_target {
  // Note [Weak references for intrusive refcounting]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Here's the scheme:
  //
  //  - refcount == number of strong references to the object
  //    weakcount == number of weak references to the object,
  //      plus one more if refcount > 0
  //    An invariant: refcount > 0  =>  weakcount > 0
  //
  //  - c10::StorageImpl stays live as long as there are any strong
  //    or weak pointers to it (weakcount > 0, since strong
  //    references count as a +1 to weakcount)
  //
  //  - finalizers are called and data_ptr is deallocated when refcount == 0
  //
  //  - Once refcount == 0, it can never again be > 0 (the transition
  //    from > 0 to == 0 is monotonic)
  //
  //  - When you access c10::StorageImpl via a weak pointer, you must
  //    atomically increment the use count, if it is greater than 0.
  //    If it is not, you must report that the storage is dead.
  //
  mutable std::atomic<uint32_t> refcount_;
  mutable std::atomic<uint32_t> weakcount_;

  // intrusive_ptr / weak_intrusive_ptr and the raw incref helpers manipulate
  // refcount_/weakcount_ directly, hence the friendships.
  template <typename T, typename NullType>
  friend class intrusive_ptr;
  friend inline void raw::intrusive_ptr::incref(intrusive_ptr_target* self);

  template <typename T, typename NullType>
  friend class weak_intrusive_ptr;
  friend inline void raw::weak_intrusive_ptr::incref(
      intrusive_ptr_target* self);

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;

 protected:
  // protected destructor. We never want to destruct intrusive_ptr_target*
  // directly.
  virtual ~intrusive_ptr_target() {
// Disable -Wterminate and -Wexceptions so we're allowed to use assertions
// (i.e. throw exceptions) in a destructor.
// We also have to disable -Wunknown-warning-option and -Wpragmas, because
// some other compilers don't know about -Wterminate or -Wexceptions and
// will show a warning about unknown warning options otherwise.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning( \
    disable : 4297) // function assumed not to throw an exception but does
#else
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Wunknown-warning-option"
#pragma GCC diagnostic ignored "-Wterminate"
#pragma GCC diagnostic ignored "-Wexceptions"
#endif
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // Second condition is there to accommodate
        // unsafe_adapt_non_heap_allocated: since we are doing our own
        // deallocation in that case, it is correct for each
        // expected_decref to have happened (some user code tried to
        // decref and thus free the object, but it didn't happen right
        // away) or not (no user code tried to free the object, and
        // now it's getting destroyed through whatever mechanism the
        // caller of unsafe_adapt_non_heap_allocated wanted to
        // use). We choose our reference count such that the count
        // will not dip below kImpracticallyHugeReferenceCount regardless.
        refcount_.load() == 0 ||
            refcount_.load() >= detail::kImpracticallyHugeReferenceCount,
        "Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it; refcount was ",
        refcount_.load());
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // See ~intrusive_ptr for optimization that will frequently result in 1
        // at destruction time.
        weakcount_.load() == 1 || weakcount_.load() == 0 ||
            weakcount_.load() == detail::kImpracticallyHugeReferenceCount - 1 ||
            weakcount_.load() == detail::kImpracticallyHugeReferenceCount,
        "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#else
#pragma GCC diagnostic pop
#endif
  }

  // Both counts start at zero; make_intrusive / the owning intrusive_ptr
  // constructor set them to 1 after allocation. A zero count is how
  // stack-allocated objects are detected (see Note [Stack allocated
  // intrusive_ptr_target safety]).
  constexpr intrusive_ptr_target() noexcept : refcount_(0), weakcount_(0) {}

  // intrusive_ptr_target supports copy and move: but refcount and weakcount
  // don't participate (since they are intrinsic properties of the memory
  // location)
  intrusive_ptr_target(intrusive_ptr_target&& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(intrusive_ptr_target&& /*other*/) noexcept {
    return *this;
  }

  intrusive_ptr_target(const intrusive_ptr_target& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(
      const intrusive_ptr_target& /*other*/) noexcept {
    return *this;
  }

 private:
  /**
   * This is called when refcount reaches zero.
   * You can override this to release expensive resources.
   * There might still be weak references, so your object might not get
   * destructed yet, but you can assume the object isn't used anymore,
   * i.e. no more calls to methods or accesses to members (we just can't
   * destruct it yet because we need the weakcount accessible).
   *
   * If there are no weak references (i.e. your class is about to be
   * destructed), this function WILL NOT be called.
   */
  virtual void release_resources() {}
};
181 
namespace detail {
// Default NullType policy for intrusive_ptr/weak_intrusive_ptr: the "null"
// sentinel is simply nullptr.
template <class TTarget>
struct intrusive_target_default_null_type final {
  static constexpr TTarget* singleton() noexcept {
    return nullptr;
  }
};

// Translate a pointer between two NullType policies: if `rhs` is the source
// policy's sentinel it becomes the destination policy's sentinel, otherwise
// it passes through unchanged.
template <class TTarget, class ToNullType, class FromNullType>
TTarget* assign_ptr_(TTarget* rhs) {
  return rhs == FromNullType::singleton() ? ToNullType::singleton() : rhs;
}

// Increment needs to be acquire-release to make use_count() and
// unique() reliable. Returns the post-increment value.
inline uint32_t atomic_refcount_increment(std::atomic<uint32_t>& counter) {
  return counter.fetch_add(1, std::memory_order_acq_rel) + 1;
}

// weak_use_count() is only used for testing, so we don't need it to
// be reliable. Relaxed should be fine. Returns the post-increment value.
inline uint32_t atomic_weakcount_increment(std::atomic<uint32_t>& counter) {
  return counter.fetch_add(1, std::memory_order_relaxed) + 1;
}

// Both decrements need to be acquire-release for correctness. See
// e.g. std::shared_ptr implementation. Returns the post-decrement value.
inline uint32_t atomic_refcount_decrement(std::atomic<uint32_t>& counter) {
  return counter.fetch_sub(1, std::memory_order_acq_rel) - 1;
}

inline uint32_t atomic_weakcount_decrement(std::atomic<uint32_t>& counter) {
  return counter.fetch_sub(1, std::memory_order_acq_rel) - 1;
}

} // namespace detail
222 
template <class TTarget, class NullType>
class weak_intrusive_ptr;

template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>>
class intrusive_ptr final {
 private:
//  the following static assert would be nice to have but it requires
//  the target class T to be fully defined when intrusive_ptr<T> is instantiated
//  this is a problem for classes that contain pointers to themselves
//  static_assert(
//      std::is_base_of<intrusive_ptr_target, TTarget>::value,
//      "intrusive_ptr can only be used for classes that inherit from
//      intrusive_ptr_target.");
#ifndef _WIN32
  // This static_assert triggers on MSVC
  //  error C2131: expression did not evaluate to a constant
  static_assert(
      // NOLINTNEXTLINE(misc-redundant-expression)
      NullType::singleton() == NullType::singleton(),
      "NullType must have a constexpr singleton() method");
#endif
  static_assert(
      std::is_base_of_v<
          TTarget,
          std::remove_pointer_t<decltype(NullType::singleton())>>,
      "NullType::singleton() must return a element_type* pointer");

  TTarget* target_;

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;
  template <class TTarget2, class NullType2>
  friend class intrusive_ptr;
  friend class weak_intrusive_ptr<TTarget, NullType>;

  // Make pybind11::class_ be a friend class of intrusive_ptr, so that custom
  // smart holder in pybind11 could access the private constructor of
  // intrusive_ptr(T*) which took the ownership of the object. This is required
  // by customer holder macro PYBIND11_DECLARE_HOLDER_TYPE, where it uses
  // intrusive_ptr(TTarget*) to initialize and take ownership of the object. For
  // details, see
  // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers
  template <typename, typename...>
  friend class pybind11::class_;

  // Bump refcount_. Must never be called once the refcount has reached zero
  // (the object may already be finalized); the debug assert catches such
  // resurrection attempts.
  void retain_() {
    if (target_ != NullType::singleton()) {
      uint32_t new_refcount =
          detail::atomic_refcount_increment(target_->refcount_);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          new_refcount != 1,
          "intrusive_ptr: Cannot increase refcount after it reached zero.");
    }
  }

  // Drop our strong reference. On the refcount 1->0 transition, run
  // release_resources() (unless we can delete immediately) and delete the
  // target once the weakcount also reaches zero.
  void reset_() noexcept {
    if (target_ != NullType::singleton() &&
        detail::atomic_refcount_decrement(target_->refcount_) == 0) {
      // See comment above about weakcount. As long as refcount>0,
      // weakcount is one larger than the actual number of weak references.
      // So we need to decrement it here.
      bool should_delete =
          target_->weakcount_.load(std::memory_order_acquire) == 1;
      if (!should_delete) {
        // justification for const_cast: release_resources is basically a
        // destructor and a destructor always mutates the object, even for
        // const objects.
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        const_cast<std::remove_const_t<TTarget>*>(target_)->release_resources();
        should_delete =
            detail::atomic_weakcount_decrement(target_->weakcount_) == 0;
      }
      if (should_delete) {
        delete target_;
      }
    }
  }

  // raw pointer constructors are not public because we shouldn't make
  // intrusive_ptr out of raw pointers except from inside the make_intrusive(),
  // reclaim() and weak_intrusive_ptr::lock() implementations.

  // This constructor will increase the ref counter for you.
  // This constructor will be used by the make_intrusive(), and also pybind11,
  // which wrap the intrusive_ptr holder around the raw pointer and incref
  // correspondingly (pybind11 requires raw pointer constructor to incref by
  // default).
  explicit intrusive_ptr(TTarget* target)
      : intrusive_ptr(target, raw::DontIncreaseRefcount{}) {
    if (target_ != NullType::singleton()) {
      // We just created result.target_, so we know no other thread has
      // access to it, so we know we needn't care about memory ordering.
      // (On x86_64, a store with memory_order_relaxed generates a plain old
      // `mov`, whereas an atomic increment does a lock-prefixed `add`, which is
      // much more expensive: https://godbolt.org/z/eKPzj8.)
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          target_->refcount_ == 0 && target_->weakcount_ == 0,
          "intrusive_ptr: Newly-created target had non-zero refcounts. Does its "
          "constructor do something strange like incref or create an "
          "intrusive_ptr from `this`?");
      target_->refcount_.store(1, std::memory_order_relaxed);
      target_->weakcount_.store(1, std::memory_order_relaxed);
    }
  }

 public:
  using element_type = TTarget;

  // Default/null construction: holds NullType's sentinel, touches no counts.
  intrusive_ptr() noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  intrusive_ptr(std::nullptr_t) noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  // This constructor will not increase the ref counter for you.
  // We use the tagged dispatch mechanism to explicitly mark this constructor
  // to not increase the refcount
  explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept
      : target_(target) {}

  // Takes over a unique_ptr's ownership; the delegated-to raw-pointer
  // constructor initializes the counts to 1.
  explicit intrusive_ptr(std::unique_ptr<TTarget> rhs) noexcept
      : intrusive_ptr(rhs.release()) {}

  // Move: steals rhs's reference; rhs becomes null. No count changes.
  intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
    rhs.target_ = NullType::singleton();
  }

  // Converting move from an intrusive_ptr to a compatible (base) type.
  template <class From, class FromNullType>
  // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
  /* implicit */ intrusive_ptr(intrusive_ptr<From, FromNullType>&& rhs) noexcept
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move constructor got pointer of wrong type.");
    rhs.target_ = FromNullType::singleton();
  }

  // Copy: shares ownership; bumps refcount.
  intrusive_ptr(const intrusive_ptr& rhs) : target_(rhs.target_) {
    retain_();
  }

  // Converting copy from an intrusive_ptr to a compatible (base) type.
  template <class From, class FromNullType>
  /* implicit */ intrusive_ptr(const intrusive_ptr<From, FromNullType>& rhs)
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy constructor got pointer of wrong type.");
    retain_();
  }

  ~intrusive_ptr() noexcept {
    reset_();
  }

  intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept {
    // Delegates to the template overload with explicit template arguments.
    // NOLINTNEXTLINE(*assign*)
    return this->template operator= <TTarget, NullType>(std::move(rhs));
  }

  template <class From, class FromNullType>
  intrusive_ptr& operator=(intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move assignment got pointer of wrong type.");
    intrusive_ptr tmp = std::move(rhs);
    swap(tmp);
    return *this;
  }

  // Assignment is implemented using copy and swap. That's safe for self
  // assignment.
  // NOLINTNEXTLINE(bugprone-unhandled-self-assignment)
  intrusive_ptr& operator=(const intrusive_ptr& rhs) & noexcept {
    // NOLINTNEXTLINE(*assign-operator, *assignment-signature)
    return this->template operator= <TTarget, NullType>(rhs);
  }

  // NOTE(review): the parameter is intrusive_ptr<From, NullType>, not
  // <From, FromNullType>, so FromNullType is never deducible and this
  // overload is only reachable via the explicit template-argument call in
  // the non-template operator= above. Confirm whether NullType here is
  // intentional before "fixing" it to FromNullType.
  template <class From, class FromNullType>
  intrusive_ptr& operator=(
      const intrusive_ptr<From, NullType>& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy assignment got pointer of wrong type.");
    intrusive_ptr tmp = rhs;
    swap(tmp);
    return *this;
  }

  // Non-owning access to the raw pointer; may be the NullType sentinel.
  TTarget* get() const noexcept {
    return target_;
  }

  TTarget& operator*() const noexcept {
    return *target_;
  }

  TTarget* operator->() const noexcept {
    return target_;
  }

  operator bool() const noexcept {
    return target_ != NullType::singleton();
  }

  // Drops the reference (possibly destroying the target) and becomes null.
  void reset() noexcept {
    reset_();
    target_ = NullType::singleton();
  }

  void swap(intrusive_ptr& rhs) noexcept {
    std::swap(target_, rhs.target_);
  }

  // We do a lot of null-pointer checks in our code, good to have this be cheap.
  bool defined() const noexcept {
    return target_ != NullType::singleton();
  }

  // Current strong reference count (0 when null).
  uint32_t use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->refcount_.load(std::memory_order_acquire);
  }

  // Current weak reference count (0 when null); includes the +1 held by the
  // strong references as long as refcount > 0.
  uint32_t weak_use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->weakcount_.load(std::memory_order_acquire);
  }

  bool unique() const noexcept {
    return use_count() == 1;
  }

  /**
   * Returns an owning (!) pointer to the underlying object and makes the
   * intrusive_ptr instance invalid. That means the refcount is not decreased.
   * You *must* put the returned pointer back into a intrusive_ptr using
   * intrusive_ptr::reclaim(ptr) to properly destruct it.
   * This is helpful for C APIs.
   */
  TTarget* release() noexcept {
    // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
    TTarget* result = target_;
    target_ = NullType::singleton();
    return result;
  }

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr that takes
   * over ownership. That means the refcount is not increased.
   * This is the counter-part to intrusive_ptr::release() and the pointer
   * passed in *must* have been created using intrusive_ptr::release().
   */
  static intrusive_ptr reclaim(TTarget* owning_ptr) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        owning_ptr == NullType::singleton() ||
            owning_ptr->refcount_.load() == 0 || owning_ptr->weakcount_.load(),
        "TTarget violates the invariant that refcount > 0  =>  weakcount > 0");
    return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{});
  }

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr
   * representing a new reference, i.e. the raw pointer retains
   * ownership.
   */
  static intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
    auto ret = reclaim(owning_ptr);
    ret.retain_();
    return ret;
  }

  /**
   * Allocate a heap object with args and wrap it inside a intrusive_ptr and
   * incref. This is a helper function to let make_intrusive() access private
   * intrusive_ptr constructors.
   */
  template <class... Args>
  static intrusive_ptr make(Args&&... args) {
    return intrusive_ptr(new TTarget(std::forward<Args>(args)...));
  }

  /**
   * Turn a new instance of TTarget (e.g., literally allocated
   * using new TTarget(...) into an intrusive_ptr.  If possible,
   * use intrusive_ptr::make instead which statically guarantees
   * that the allocation was done properly.
   *
   * At the moment, the only reason this method exists is because
   * pybind11 holder types expect to be able to allocate in
   * this way (because pybind11 handles the new allocation itself).
   */
  static intrusive_ptr unsafe_steal_from_new(TTarget* raw_ptr) {
    return intrusive_ptr(raw_ptr);
  }

  /**
   * Turn an instance of TTarget that should not be reference counted
   * (e.g., allocated into an arena with placement new) into an
   * intrusive_ptr. This is gratuitously unsafe and should only be
   * used if you can guarantee that the pointer will not escape and be
   * refcounted as normal.
   *
   * `expected_decrefs` is a debugging parameter: it indicates the
   * number of strong owners the intrusive_ptr_target in question is
   * expected to get. In most use cases, this will likely be 1.
   *
   * The reason this method exists is for manually sharing
   * StorageImpls across Tensors in the static runtime. It needs
   * access to private intrusive_ptr members so that the refcounts can
   * be initialized to custom values.
   */
  static intrusive_ptr unsafe_adapt_non_heap_allocated(
      TTarget* raw_ptr,
      uint32_t expected_decrefs) {
    intrusive_ptr result(raw_ptr, raw::DontIncreaseRefcount{});
    // kImpracticallyHugeReferenceCount is impractically huge for a reference
    // count, while being in no danger of overflowing uint32_t. We actually only
    // need to initialize the refcount to 2 -- we are just doing an unbalanced
    // incref to prevent the non-heap-allocated target from being
    // freed, and we are optimizing that incref by directly
    // initializing the refcounts rather than doing an expensive
    // atomic increment. The reason to use kImpracticallyHugeReferenceCount is
    // to accommodate the debug assertions in ~intrusive_ptr_target.
#ifdef NDEBUG
    expected_decrefs = 0;
#endif
    result.target_->refcount_.store(
        detail::kImpracticallyHugeReferenceCount + expected_decrefs,
        std::memory_order_relaxed);
    result.target_->weakcount_.store(
        detail::kImpracticallyHugeReferenceCount, std::memory_order_relaxed);
    return result;
  }

  /**
   * Turn a **non-owning raw pointer** to an intrusive_ptr.  It is
   * the moral equivalent of enable_shared_from_this on a shared pointer.
   *
   * This method is only valid for objects that are already live.  If
   * you are looking for the moral equivalent of unique_ptr<T>(T*)
   * constructor, see steal_from_new.
   *
   * TODO: https://github.com/pytorch/pytorch/issues/56482
   */
  static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) {
    // See Note [Stack allocated intrusive_ptr_target safety]
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        raw_ptr == NullType::singleton() || raw_ptr->refcount_.load() > 0,
        "intrusive_ptr: Can only reclaim pointers that are owned by someone");
    auto ptr = reclaim(raw_ptr); // doesn't increase refcount
    ptr.retain_();
    return ptr;
  }
};
584 
585 template <
586     class TTarget,
587     class NullType = detail::intrusive_target_default_null_type<TTarget>,
588     class... Args>
make_intrusive(Args &&...args)589 inline intrusive_ptr<TTarget, NullType> make_intrusive(Args&&... args) {
590   return intrusive_ptr<TTarget, NullType>::make(std::forward<Args>(args)...);
591 }
592 
// ADL-enabled swap for intrusive_ptr, mirroring std::swap conventions:
// exchanges the two owned pointers without touching any refcounts.
template <class TTarget, class NullType>
inline void swap(
    intrusive_ptr<TTarget, NullType>& lhs,
    intrusive_ptr<TTarget, NullType>& rhs) noexcept {
  lhs.swap(rhs);
}
599 
600 // To allow intrusive_ptr inside std::map or std::set, we need operator<
601 template <class TTarget1, class NullType1, class TTarget2, class NullType2>
602 inline bool operator<(
603     const intrusive_ptr<TTarget1, NullType1>& lhs,
604     const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
605   return lhs.get() < rhs.get();
606 }
607 
608 template <class TTarget1, class NullType1, class TTarget2, class NullType2>
609 inline bool operator==(
610     const intrusive_ptr<TTarget1, NullType1>& lhs,
611     const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
612   return lhs.get() == rhs.get();
613 }
614 
615 template <class TTarget1, class NullType1>
616 inline bool operator==(
617     const intrusive_ptr<TTarget1, NullType1>& lhs,
618     std::nullptr_t) noexcept {
619   return lhs.get() == nullptr;
620 }
621 
622 template <class TTarget2, class NullType2>
623 inline bool operator==(
624     std::nullptr_t,
625     const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
626   return nullptr == rhs.get();
627 }
628 
629 template <class TTarget1, class NullType1, class TTarget2, class NullType2>
630 inline bool operator!=(
631     const intrusive_ptr<TTarget1, NullType1>& lhs,
632     const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
633   return !operator==(lhs, rhs);
634 }
635 
636 template <class TTarget1, class NullType1>
637 inline bool operator!=(
638     const intrusive_ptr<TTarget1, NullType1>& lhs,
639     std::nullptr_t) noexcept {
640   return !operator==(lhs, nullptr);
641 }
642 
643 template <class TTarget2, class NullType2>
644 inline bool operator!=(
645     std::nullptr_t,
646     const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
647   return !operator==(nullptr, rhs);
648 }
// MaybeOwned support for intrusive_ptr: a "borrow" is itself an
// intrusive_ptr that aliases the owned pointer WITHOUT holding a reference.
// This works by creating borrows via reclaim() (which does not incref, per
// its contract) and disarming them via release() before they are destroyed
// or reassigned, so their destructor never decrefs.
template <typename T>
struct MaybeOwnedTraits<c10::intrusive_ptr<T>> {
  using owned_type = c10::intrusive_ptr<T>;
  using borrow_type = c10::intrusive_ptr<T>;

  static borrow_type createBorrow(const owned_type& from) {
    // reclaim() adopts the raw pointer without increfing: the borrow aliases
    // `from` and must outlive no longer than it.
    return borrow_type::reclaim(from.get());
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    // release() first so overwriting lhs doesn't decref the old target.
    lhs.release();
    lhs = borrow_type::reclaim(rhs.get());
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    // Disarm the borrow so its destructor does not decrement the refcount
    // it never incremented.
    toDestroy.release();
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};
679 
680 template <
681     typename TTarget,
682     class NullType = detail::intrusive_target_default_null_type<TTarget>>
683 class weak_intrusive_ptr final {
684  private:
685   static_assert(
686       std::is_base_of_v<intrusive_ptr_target, TTarget>,
687       "intrusive_ptr can only be used for classes that inherit from intrusive_ptr_target.");
688 #ifndef _WIN32
689   // This static_assert triggers on MSVC
690   //  error C2131: expression did not evaluate to a constant
691   static_assert(
692       NullType::singleton() == NullType::singleton(),
693       "NullType must have a constexpr singleton() method");
694 #endif
695   static_assert(
696       std::is_base_of_v<
697           TTarget,
698           std::remove_pointer_t<decltype(NullType::singleton())>>,
699       "NullType::singleton() must return a element_type* pointer");
700 
701   TTarget* target_;
702 
703   template <class TTarget2, class NullType2>
704   friend class weak_intrusive_ptr;
705 
  // Bump weakcount_. Must only be called while some other reference keeps
  // weakcount > 0 (the debug assert catches resurrection from zero).
  void retain_() {
    if (target_ != NullType::singleton()) {
      uint32_t new_weakcount =
          detail::atomic_weakcount_increment(target_->weakcount_);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          new_weakcount != 1,
          "weak_intrusive_ptr: Cannot increase weakcount after it reached zero.");
    }
  }

  // Drop our weak reference and become null. The final weakcount decrement
  // deletes the target (per the invariant refcount > 0 => weakcount > 0,
  // refcount is already zero at that point, so resources were released).
  void reset_() noexcept {
    if (target_ != NullType::singleton() &&
        detail::atomic_weakcount_decrement(target_->weakcount_) == 0) {
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
      delete target_;
    }
    target_ = NullType::singleton();
  }

  // Adopts `target` WITHOUT incrementing weakcount. Private: callers must
  // pair it with retain_() (see the public constructor from intrusive_ptr)
  // or hand over an already-counted pointer.
  constexpr explicit weak_intrusive_ptr(TTarget* target) : target_(target) {}
726 
 public:
  using element_type = TTarget;

  // Creates a weak reference observing `ptr`'s target (bumps weakcount).
  explicit weak_intrusive_ptr(const intrusive_ptr<TTarget, NullType>& ptr)
      : weak_intrusive_ptr(ptr.get()) {
    retain_();
  }

  // Move: steals rhs's weak reference; rhs becomes null. No count changes.
  weak_intrusive_ptr(weak_intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
    rhs.target_ = NullType::singleton();
  }

  // Converting move from a weak_intrusive_ptr to a compatible (base) type.
  template <class From, class FromNullType>
  /* implicit */ weak_intrusive_ptr(
      // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
      weak_intrusive_ptr<From, FromNullType>&& rhs) noexcept
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr move constructor got pointer of wrong type.");
    rhs.target_ = FromNullType::singleton();
  }

  // Copy: shares the observation; bumps weakcount.
  weak_intrusive_ptr(const weak_intrusive_ptr& rhs) : target_(rhs.target_) {
    retain_();
  }

  // Converting copy from a weak_intrusive_ptr to a compatible (base) type.
  template <class From, class FromNullType>
  /* implicit */ weak_intrusive_ptr(
      const weak_intrusive_ptr<From, FromNullType>& rhs)
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr copy constructor got pointer of wrong type.");
    retain_();
  }
765 
  // Drops this weak reference; may free the target's memory if it was the
  // last weak reference (see reset_()).
  ~weak_intrusive_ptr() noexcept {
    reset_();
  }
769 
  weak_intrusive_ptr& operator=(weak_intrusive_ptr&& rhs) & noexcept {
    // Delegates to the converting move assignment below.
    // NOLINTNEXTLINE(*assign*)
    return this->template operator= <TTarget, NullType>(std::move(rhs));
  }

  // Converting move assignment; implemented as move-and-swap so the
  // temporary's destructor releases our previous target.  Safe for
  // self-assignment.
  template <class From, class FromNullType>
  weak_intrusive_ptr& operator=(
      weak_intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr move assignment got pointer of wrong type.");
    weak_intrusive_ptr tmp = std::move(rhs);
    swap(tmp);
    return *this;
  }
785 
  weak_intrusive_ptr& operator=(const weak_intrusive_ptr& rhs) & noexcept {
    // Fast path for self-assignment (the copy-and-swap delegate below
    // would also be correct, just does extra refcount work).
    if (this == &rhs) {
      return *this;
    }
    // NOLINTNEXTLINE(*assign*)
    return this->template operator= <TTarget, NullType>(rhs);
  }

  // Rebinds this weak reference to the object strongly held by `rhs`;
  // the previous target (if any) is released via the temporary's dtor.
  weak_intrusive_ptr& operator=(
      const intrusive_ptr<TTarget, NullType>& rhs) & noexcept {
    weak_intrusive_ptr tmp(rhs);
    swap(tmp);
    return *this;
  }
800 
801   template <class From, class FromNullType>
802   weak_intrusive_ptr& operator=(
803       const weak_intrusive_ptr<From, NullType>& rhs) & noexcept {
804     static_assert(
805         std::is_convertible<From*, TTarget*>::value,
806         "Type mismatch. weak_intrusive_ptr copy assignment got pointer of wrong type.");
807     weak_intrusive_ptr tmp = rhs;
808     swap(tmp);
809     return *this;
810   }
811 
  // Detaches from the target: decrements the weakcount and may free the
  // object's memory if this was the last weak reference (see reset_()).
  void reset() noexcept {
    reset_();
  }
815 
816   void swap(weak_intrusive_ptr& rhs) noexcept {
817     TTarget* tmp = target_;
818     target_ = rhs.target_;
819     rhs.target_ = tmp;
820   }
821 
  // NB: This should ONLY be used by the std::hash implementation
  // for weak_intrusive_ptr.  Another way you could do this is
  // friend std::hash<weak_intrusive_ptr>, but this triggers two
  // bugs:
  //
  //  (1) It triggers an nvcc bug, where std::hash in a friend class
  //      declaration gets preprocessed into hash, which then cannot
  //      actually be found.  The error in this case looks like:
  //
  //        error: no template named 'hash'; did you mean 'std::hash'?
  //
  //  (2) On OS X, std::hash is declared as a struct, not a class.
  //      This twings:
  //
  //        error: class 'hash' was previously declared as a struct
  //        [-Werror,-Wmismatched-tags]
  //
  // Both of these are work-aroundable, but on the whole, I decided
  // it would be simpler and easier to make work if we just expose
  // an unsafe getter for target_
  //
  // Returns the raw target without touching any refcounts; the pointee is
  // NOT guaranteed to be alive.  Use only for identity/hashing.
  TTarget* _unsafe_get_target() const noexcept {
    return target_;
  }
846 
847   uint32_t use_count() const noexcept {
848     if (target_ == NullType::singleton()) {
849       return 0;
850     }
851     return target_->refcount_.load(
852         std::memory_order_acquire); // refcount, not weakcount!
853   }
854 
855   uint32_t weak_use_count() const noexcept {
856     if (target_ == NullType::singleton()) {
857       return 0;
858     }
859     return target_->weakcount_.load(std::memory_order_acquire);
860   }
861 
  // True iff the target has no strong references left (or this is a null
  // weak pointer), i.e. lock() would return a null intrusive_ptr.
  bool expired() const noexcept {
    return use_count() == 0;
  }
865 
  // Attempts to promote this weak reference to a strong one.  Returns a
  // null intrusive_ptr if the object has already been destructed.
  intrusive_ptr<TTarget, NullType> lock() const noexcept {
    if (expired()) {
      return intrusive_ptr<TTarget, NullType>();
    } else {
      auto refcount = target_->refcount_.load(std::memory_order_seq_cst);
      do {
        if (refcount == 0) {
          // Object already destructed, no strong references left anymore.
          // Return nullptr.
          return intrusive_ptr<TTarget, NullType>();
        }
        // CAS loop: only increment refcount if it is still nonzero, so we
        // never resurrect an object whose destruction has already begun.
      } while (
          !target_->refcount_.compare_exchange_weak(refcount, refcount + 1));
      // The CAS above already incremented the refcount, so hand the
      // pointer to intrusive_ptr without a second increment.
      return intrusive_ptr<TTarget, NullType>(
          target_, raw::DontIncreaseRefcount{});
    }
  }
883 
884   /**
885    * Returns an owning (but still only weakly referenced) pointer to the
886    * underlying object and makes the weak_intrusive_ptr instance invalid.
887    * That means the weakcount is not decreased.
888    * You *must* put the returned pointer back into a weak_intrusive_ptr using
889    * weak_intrusive_ptr::reclaim(ptr) to properly destruct it.
890    * This is helpful for C APIs.
891    */
892   TTarget* release() noexcept {
893     TTarget* result = target_;
894     target_ = NullType::singleton();
895     return result;
896   }
897 
898   /**
899    * Takes an owning (but must be weakly referenced) pointer to TTarget* and
900    * creates a weak_intrusive_ptr that takes over ownership.
901    * This means that the weakcount is not increased.
902    * This is the counter-part to weak_intrusive_ptr::release() and the pointer
903    * passed in *must* have been created using weak_intrusive_ptr::release().
904    */
905   static weak_intrusive_ptr reclaim(TTarget* owning_weak_ptr) {
906     // See Note [Stack allocated intrusive_ptr_target safety]
907     // if refcount > 0, weakcount must be >1 for weak references to exist.
908     // see weak counting explanation at top of this file.
909     // if refcount == 0, weakcount only must be >0.
910     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
911         owning_weak_ptr == NullType::singleton() ||
912             owning_weak_ptr->weakcount_.load() > 1 ||
913             (owning_weak_ptr->refcount_.load() == 0 &&
914              owning_weak_ptr->weakcount_.load() > 0),
915         "weak_intrusive_ptr: Can only weak_intrusive_ptr::reclaim() owning pointers that were created using weak_intrusive_ptr::release().");
916     return weak_intrusive_ptr(owning_weak_ptr);
917   }
918 
919   /**
920    * Takes a pointer to TTarget* (may be weak or strong) and creates a
921    * new weak_intrusive_ptr representing a new weak reference, i.e.
922    * the raw pointer retains ownership.
923    */
924   static weak_intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
925     auto ret = reclaim(owning_ptr);
926     ret.retain_();
927     return ret;
928   }
929 
930   template <class TTarget1, class NullType1, class TTarget2, class NullType2>
931   friend bool operator<(
932       const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
933       const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
934   template <class TTarget1, class NullType1, class TTarget2, class NullType2>
935   friend bool operator==(
936       const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
937       const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
938 };
939 
// ADL-enabled free swap for weak_intrusive_ptr; forwards to the member
// swap (pointer exchange only, no refcount traffic).
template <class TTarget, class NullType>
inline void swap(
    weak_intrusive_ptr<TTarget, NullType>& lhs,
    weak_intrusive_ptr<TTarget, NullType>& rhs) noexcept {
  lhs.swap(rhs);
}
946 
// To allow weak_intrusive_ptr inside std::map or std::set, we need operator<
// Ordering is by target identity (raw pointer value), not by pointee value.
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator<(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ < rhs.target_;
}
954 
// Equality is target identity: two weak pointers compare equal iff they
// point at the same object (or are both null).
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator==(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ == rhs.target_;
}
961 
962 template <class TTarget1, class NullType1, class TTarget2, class NullType2>
963 inline bool operator!=(
964     const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
965     const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
966   return !operator==(lhs, rhs);
967 }
968 
969 // Alias for documentary purposes, to more easily distinguish
970 // weak raw intrusive pointers from intrusive pointers.
971 using weak_intrusive_ptr_target = intrusive_ptr_target;
972 
973 // This namespace provides some methods for working with
974 // raw pointers that subclass intrusive_ptr_target.  They are not provided
975 // as methods on intrusive_ptr_target, because ideally you would not need these
976 // methods at all (use smart pointers), but if you are dealing with legacy code
977 // that still needs to pass around raw pointers, you may find these quite
978 // useful.
979 //
980 // An important usage note: some functions are only valid if you have a
981 // strong raw pointer to the object, while others are only valid if you
982 // have a weak raw pointer to the object.  ONLY call intrusive_ptr namespace
983 // functions on strong pointers, and weak_intrusive_ptr namespace functions
984 // on weak pointers.  If you mix it up, you may get an assert failure.
985 namespace raw {
986 
987 namespace intrusive_ptr {
988 
989 // WARNING: Unlike the reclaim() API, it is NOT valid to pass
990 // NullType::singleton to this function
991 inline void incref(intrusive_ptr_target* self) {
992   if (self) {
993     detail::atomic_refcount_increment(self->refcount_);
994   }
995 }
996 
// WARNING: Unlike the reclaim() API, it is NOT valid to pass
// NullType::singleton to this function
inline void decref(intrusive_ptr_target* self) {
  // Let it die
  // Adopting the pointer into a temporary intrusive_ptr whose destructor
  // runs immediately performs the decrement (and possibly the delete).
  c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: Caller still has 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::intrusive_ptr class
}
1005 
// Given a raw STRONG pointer, returns a raw WEAK pointer to the same
// object (weakcount +1); the caller's strong reference is untouched.
template <typename T>
inline T* make_weak(T* self) {
  // NB: 'this' is a strong pointer, but we return a weak pointer
  auto ptr = c10::intrusive_ptr<T>::reclaim(self); // adopt, no incref
  c10::weak_intrusive_ptr<T> wptr(ptr); // weakcount +1
  ptr.release(); // give the strong reference back to the caller
  return wptr.release(); // hand out the new weak reference as a raw ptr
}
1014 
1015 inline uint32_t use_count(intrusive_ptr_target* self) {
1016   auto ptr = c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
1017   auto r = ptr.use_count();
1018   ptr.release();
1019   return r;
1020 }
1021 
1022 } // namespace intrusive_ptr
1023 
1024 namespace weak_intrusive_ptr {
1025 
// Adds one weak reference to 'self'.  Note: unlike the strong
// raw::intrusive_ptr::incref above, there is no null check here.
inline void incref(weak_intrusive_ptr_target* self) {
  detail::atomic_weakcount_increment(self->weakcount_);
}
1029 
inline void decref(weak_intrusive_ptr_target* self) {
  // Let it die
  // A temporary weak_intrusive_ptr adopts the reference and its immediate
  // destruction decrements the weakcount (and possibly frees the memory).
  c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: You still "have" the 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::weak_intrusive_ptr class
}
1036 
// Attempts to promote a raw WEAK pointer to a raw STRONG pointer (null on
// failure).  The caller's weak reference remains valid either way.
template <typename T>
inline T* lock(T* self) {
  auto wptr = c10::weak_intrusive_ptr<T>::reclaim(self);
  auto ptr = wptr.lock();
  wptr.release(); // give the weak reference back to the caller
  return ptr.release(); // caller now owns one strong reference (if any)
}
1044 
1045 // This gives the STRONG refcount of a WEAK pointer
1046 inline uint32_t use_count(weak_intrusive_ptr_target* self) {
1047   auto wptr = c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
1048   auto r = wptr.use_count();
1049   wptr.release();
1050   return r;
1051 }
1052 
1053 } // namespace weak_intrusive_ptr
1054 
1055 } // namespace raw
1056 
1057 } // namespace c10
1058 
1059 namespace std {
1060 // To allow intrusive_ptr and weak_intrusive_ptr inside std::unordered_map or
1061 // std::unordered_set, we need std::hash
// Hashes an intrusive_ptr by target pointer identity.
template <class TTarget, class NullType>
struct hash<c10::intrusive_ptr<TTarget, NullType>> {
  size_t operator()(const c10::intrusive_ptr<TTarget, NullType>& x) const {
    return std::hash<TTarget*>()(x.get());
  }
};
// Hashes a weak_intrusive_ptr by target pointer identity (via the unsafe
// getter; the pointee need not be alive to be hashed).
template <class TTarget, class NullType>
struct hash<c10::weak_intrusive_ptr<TTarget, NullType>> {
  size_t operator()(const c10::weak_intrusive_ptr<TTarget, NullType>& x) const {
    return std::hash<TTarget*>()(x._unsafe_get_target());
  }
};
1074 } // namespace std
1075