1 #pragma once
2
3 #include <c10/util/Exception.h>
4 #include <c10/util/MaybeOwned.h>
5 #include <atomic>
6 #include <climits>
7 #include <memory>
8 #include <type_traits>
9
10 namespace pybind11 {
11 template <typename, typename...>
12 class class_;
13 }
14
15 namespace c10 {
16 class intrusive_ptr_target;
17 namespace raw {
18 namespace weak_intrusive_ptr {
19 inline void incref(intrusive_ptr_target* self);
20 }
21 namespace intrusive_ptr {
22 inline void incref(intrusive_ptr_target* self);
23 }
24
25 // constructor tag used by intrusive_ptr constructors
26 struct DontIncreaseRefcount {};
27 } // namespace raw
28
namespace detail {
// A refcount value far larger than any real program would reach, yet in no
// danger of overflowing uint32_t. intrusive_ptr::unsafe_adapt_non_heap_allocated
// stores it into refcount_/weakcount_ to pin non-heap targets, and the debug
// assertions in ~intrusive_ptr_target recognize it as a legitimate final state.
constexpr uint32_t kImpracticallyHugeReferenceCount = 0x0FFFFFFF;
} // namespace detail
32
33 /**
34 * intrusive_ptr<T> is an alternative to shared_ptr<T> that has better
35 * performance because it does the refcounting intrusively
36 * (i.e. in a member of the object itself).
37 * Your class T needs to inherit from intrusive_ptr_target to allow it to be
38 * used in an intrusive_ptr<T>. Your class's constructor should not allow
 * `this` to escape to other threads or create an intrusive_ptr from `this`.
40 */
41
42 // Note [Stack allocated intrusive_ptr_target safety]
43 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
44 // A well known problem with std::enable_shared_from_this is that it
45 // allows you to create a std::shared_ptr from a stack allocated object,
46 // which is totally bogus because the object will die once you return
47 // from the stack. In intrusive_ptr, we can detect that this has occurred,
48 // because we set the refcount/weakcount of objects which inherit from
49 // intrusive_ptr_target to zero, *unless* we can prove that the object
50 // was dynamically allocated (e.g., via make_intrusive).
51 //
52 // Thus, whenever you transmute a T* into a intrusive_ptr<T>, we check
53 // and make sure that the refcount isn't zero (or, a more subtle
54 // test for weak_intrusive_ptr<T>, for which the refcount may validly
55 // be zero, but the weak refcount better not be zero), because that
56 // tells us if the object was allocated by us. If it wasn't, no
57 // intrusive_ptr for you!
58
59 // NOLINTNEXTLINE(cppcoreguidelines-virtual-class-destructor)
// NOLINTNEXTLINE(cppcoreguidelines-virtual-class-destructor)
class C10_API intrusive_ptr_target {
  // Note [Weak references for intrusive refcounting]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Here's the scheme:
  //
  // - refcount == number of strong references to the object
  //   weakcount == number of weak references to the object,
  //   plus one more if refcount > 0
  //   An invariant: refcount > 0 => weakcount > 0
  //
  // - c10::StorageImpl stays live as long as there are any strong
  //   or weak pointers to it (weakcount > 0, since strong
  //   references count as a +1 to weakcount)
  //
  // - finalizers are called and data_ptr is deallocated when refcount == 0
  //
  // - Once refcount == 0, it can never again be > 0 (the transition
  //   from > 0 to == 0 is monotonic)
  //
  // - When you access c10::StorageImpl via a weak pointer, you must
  //   atomically increment the use count, if it is greater than 0.
  //   If it is not, you must report that the storage is dead.
  //
  // Both counters start at zero (see the default constructor below); per
  // Note [Stack allocated intrusive_ptr_target safety], zero is how we
  // recognize objects that were NOT handed to us via make_intrusive.
  mutable std::atomic<uint32_t> refcount_;
  mutable std::atomic<uint32_t> weakcount_;

  // Only the smart-pointer machinery may touch the counters directly.
  template <typename T, typename NullType>
  friend class intrusive_ptr;
  friend inline void raw::intrusive_ptr::incref(intrusive_ptr_target* self);

  template <typename T, typename NullType>
  friend class weak_intrusive_ptr;
  friend inline void raw::weak_intrusive_ptr::incref(
      intrusive_ptr_target* self);

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;

 protected:
  // protected destructor. We never want to destruct intrusive_ptr_target*
  // directly.
  virtual ~intrusive_ptr_target() {
    // Disable -Wterminate and -Wexceptions so we're allowed to use assertions
    // (i.e. throw exceptions) in a destructor.
    // We also have to disable -Wunknown-warning-option and -Wpragmas, because
    // some other compilers don't know about -Wterminate or -Wexceptions and
    // will show a warning about unknown warning options otherwise.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning( \
    disable : 4297) // function assumed not to throw an exception but does
#else
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Wunknown-warning-option"
#pragma GCC diagnostic ignored "-Wterminate"
#pragma GCC diagnostic ignored "-Wexceptions"
#endif
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // Second condition is there to accommodate
        // unsafe_adapt_non_heap_allocated: since we are doing our own
        // deallocation in that case, it is correct for each
        // expected_decref to have happened (some user code tried to
        // decref and thus free the object, but it didn't happen right
        // away) or not (no user code tried to free the object, and
        // now it's getting destroyed through whatever mechanism the
        // caller of unsafe_adapt_non_heap_allocated wanted to
        // use). We choose our reference count such that the count
        // will not dip below kImpracticallyHugeReferenceCount regardless.
        refcount_.load() == 0 ||
            refcount_.load() >= detail::kImpracticallyHugeReferenceCount,
        "Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it; refcount was ",
        refcount_.load());
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // See ~intrusive_ptr for optimization that will frequently result in 1
        // at destruction time.
        weakcount_.load() == 1 || weakcount_.load() == 0 ||
            weakcount_.load() == detail::kImpracticallyHugeReferenceCount - 1 ||
            weakcount_.load() == detail::kImpracticallyHugeReferenceCount,
        "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#else
#pragma GCC diagnostic pop
#endif
  }

  // Counters start at zero; make_intrusive (via the private raw-pointer
  // intrusive_ptr constructor) is what bumps them to 1 after heap allocation.
  constexpr intrusive_ptr_target() noexcept : refcount_(0), weakcount_(0) {}

  // intrusive_ptr_target supports copy and move: but refcount and weakcount
  // don't participate (since they are intrinsic properties of the memory
  // location)
  intrusive_ptr_target(intrusive_ptr_target&& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(intrusive_ptr_target&& /*other*/) noexcept {
    return *this;
  }

  intrusive_ptr_target(const intrusive_ptr_target& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(
      const intrusive_ptr_target& /*other*/) noexcept {
    return *this;
  }

 private:
  /**
   * This is called when refcount reaches zero.
   * You can override this to release expensive resources.
   * There might still be weak references, so your object might not get
   * destructed yet, but you can assume the object isn't used anymore,
   * i.e. no more calls to methods or accesses to members (we just can't
   * destruct it yet because we need the weakcount accessible).
   *
   * If there are no weak references (i.e. your class is about to be
   * destructed), this function WILL NOT be called.
   */
  virtual void release_resources() {}
};
181
182 namespace detail {
template <class TTarget>
struct intrusive_target_default_null_type final {
  // Default "null" sentinel for intrusive_ptr<TTarget>: a plain null pointer.
  static constexpr TTarget* singleton() noexcept {
    return nullptr;
  }
};
189
// Translate a raw pointer from one NullType convention to another: the
// source's null sentinel maps onto the destination's sentinel, and any
// other pointer passes through unchanged.
template <class TTarget, class ToNullType, class FromNullType>
TTarget* assign_ptr_(TTarget* rhs) {
  return rhs == FromNullType::singleton() ? ToNullType::singleton() : rhs;
}
198
// Increment needs to be acquire-release to make use_count() and
// unique() reliable.
inline uint32_t atomic_refcount_increment(std::atomic<uint32_t>& refcount) {
  // fetch_add returns the value *before* the addition; report the new count.
  const uint32_t old_count = refcount.fetch_add(1, std::memory_order_acq_rel);
  return old_count + 1;
}
204
// weak_use_count() is only used for testing, so we don't need it to
// be reliable. Relaxed should be fine.
inline uint32_t atomic_weakcount_increment(std::atomic<uint32_t>& weakcount) {
  const uint32_t old_count = weakcount.fetch_add(1, std::memory_order_relaxed);
  return old_count + 1;
}
210
// Both decrements need to be acquire-release for correctness. See
// e.g. std::shared_ptr implementation.
inline uint32_t atomic_refcount_decrement(std::atomic<uint32_t>& refcount) {
  // fetch_sub returns the value *before* the subtraction; report the new count.
  const uint32_t old_count = refcount.fetch_sub(1, std::memory_order_acq_rel);
  return old_count - 1;
}
216
atomic_weakcount_decrement(std::atomic<uint32_t> & weakcount)217 inline uint32_t atomic_weakcount_decrement(std::atomic<uint32_t>& weakcount) {
218 return weakcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
219 }
220
221 } // namespace detail
222
223 template <class TTarget, class NullType>
224 class weak_intrusive_ptr;
225
template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>>
class intrusive_ptr final {
 private:
  // the following static assert would be nice to have but it requires
  // the target class T to be fully defined when intrusive_ptr<T> is instantiated
  // this is a problem for classes that contain pointers to themselves
  // static_assert(
  //     std::is_base_of<intrusive_ptr_target, TTarget>::value,
  //     "intrusive_ptr can only be used for classes that inherit from
  //     intrusive_ptr_target.");
#ifndef _WIN32
  // This static_assert triggers on MSVC
  // error C2131: expression did not evaluate to a constant
  static_assert(
      // NOLINTNEXTLINE(misc-redundant-expression)
      NullType::singleton() == NullType::singleton(),
      "NullType must have a constexpr singleton() method");
#endif
  static_assert(
      std::is_base_of_v<
          TTarget,
          std::remove_pointer_t<decltype(NullType::singleton())>>,
      "NullType::singleton() must return a element_type* pointer");

  // The managed object, or NullType::singleton() when this pointer is empty.
  TTarget* target_;

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;
  template <class TTarget2, class NullType2>
  friend class intrusive_ptr;
  friend class weak_intrusive_ptr<TTarget, NullType>;

  // Make pybind11::class_ be a friend class of intrusive_ptr, so that custom
  // smart holder in pybind11 could access the private constructor of
  // intrusive_ptr(T*) which took the ownership of the object. This is required
  // by customer holder macro PYBIND11_DECLARE_HOLDER_TYPE, where it uses
  // intrusive_ptr(TTarget*) to initialize and take ownership of the object. For
  // details, see
  // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers
  template <typename, typename...>
  friend class pybind11::class_;

  // Bump the strong refcount of a non-null target. Must only be called while
  // the object is already owned: incrementing from zero would resurrect a
  // dead object, which the debug assertion catches.
  void retain_() {
    if (target_ != NullType::singleton()) {
      uint32_t new_refcount =
          detail::atomic_refcount_increment(target_->refcount_);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          new_refcount != 1,
          "intrusive_ptr: Cannot increase refcount after it reached zero.");
    }
  }

  // Drop our strong reference. On the last strong reference, call
  // release_resources() and, once no weak references remain either, delete
  // the target. Note target_ is left dangling; callers reset it as needed.
  void reset_() noexcept {
    if (target_ != NullType::singleton() &&
        detail::atomic_refcount_decrement(target_->refcount_) == 0) {
      // See comment above about weakcount. As long as refcount>0,
      // weakcount is one larger than the actual number of weak references.
      // So we need to decrement it here.
      bool should_delete =
          target_->weakcount_.load(std::memory_order_acquire) == 1;
      if (!should_delete) {
        // justification for const_cast: release_resources is basically a
        // destructor and a destructor always mutates the object, even for const
        // objects. NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        const_cast<std::remove_const_t<TTarget>*>(target_)->release_resources();
        should_delete =
            detail::atomic_weakcount_decrement(target_->weakcount_) == 0;
      }
      if (should_delete) {
        delete target_;
      }
    }
  }

  // raw pointer constructors are not public because we shouldn't make
  // intrusive_ptr out of raw pointers except from inside the make_intrusive(),
  // reclaim() and weak_intrusive_ptr::lock() implementations.

  // This constructor will increase the ref counter for you.
  // This constructor will be used by the make_intrusive(), and also pybind11,
  // which wrap the intrusive_ptr holder around the raw pointer and incref
  // correspondingly (pybind11 requires raw pointer constructor to incref by
  // default).
  explicit intrusive_ptr(TTarget* target)
      : intrusive_ptr(target, raw::DontIncreaseRefcount{}) {
    if (target_ != NullType::singleton()) {
      // We just created result.target_, so we know no other thread has
      // access to it, so we know we needn't care about memory ordering.
      // (On x86_64, a store with memory_order_relaxed generates a plain old
      // `mov`, whereas an atomic increment does a lock-prefixed `add`, which is
      // much more expensive: https://godbolt.org/z/eKPzj8.)
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          target_->refcount_ == 0 && target_->weakcount_ == 0,
          "intrusive_ptr: Newly-created target had non-zero refcounts. Does its "
          "constructor do something strange like incref or create an "
          "intrusive_ptr from `this`?");
      target_->refcount_.store(1, std::memory_order_relaxed);
      target_->weakcount_.store(1, std::memory_order_relaxed);
    }
  }

 public:
  using element_type = TTarget;

  // Default- and nullptr-construction yield an empty pointer; no refcount
  // traffic is involved.
  intrusive_ptr() noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  intrusive_ptr(std::nullptr_t) noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  // This constructor will not increase the ref counter for you.
  // We use the tagged dispatch mechanism to explicitly mark this constructor
  // to not increase the refcount
  explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept
      : target_(target) {}

  // Takes over ownership from a unique_ptr; the delegated-to raw-pointer
  // constructor initializes refcount and weakcount to 1.
  explicit intrusive_ptr(std::unique_ptr<TTarget> rhs) noexcept
      : intrusive_ptr(rhs.release()) {}

  // Move leaves rhs empty; no refcount traffic.
  intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
    rhs.target_ = NullType::singleton();
  }

  template <class From, class FromNullType>
  // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
  /* implicit */ intrusive_ptr(intrusive_ptr<From, FromNullType>&& rhs) noexcept
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move constructor got pointer of wrong type.");
    rhs.target_ = FromNullType::singleton();
  }

  intrusive_ptr(const intrusive_ptr& rhs) : target_(rhs.target_) {
    retain_();
  }

  template <class From, class FromNullType>
  /* implicit */ intrusive_ptr(const intrusive_ptr<From, FromNullType>& rhs)
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy constructor got pointer of wrong type.");
    retain_();
  }

  ~intrusive_ptr() noexcept {
    reset_();
  }

  intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept {
    // NOLINTNEXTLINE(*assign*)
    return this->template operator= <TTarget, NullType>(std::move(rhs));
  }

  template <class From, class FromNullType>
  intrusive_ptr& operator=(intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move assignment got pointer of wrong type.");
    intrusive_ptr tmp = std::move(rhs);
    swap(tmp);
    return *this;
  }

  // Assignment is implemented using copy and swap. That's safe for self
  // assignment.
  // NOLINTNEXTLINE(bugprone-unhandled-self-assignment)
  intrusive_ptr& operator=(const intrusive_ptr& rhs) & noexcept {
    // NOLINTNEXTLINE(*assign-operator, *assignment-signature)
    return this->template operator= <TTarget, NullType>(rhs);
  }

  // NOTE(review): the parameter uses NullType (not FromNullType), unlike the
  // move assignment above, so FromNullType is non-deducible for cross-NullType
  // copies — confirm whether this asymmetry is intentional.
  template <class From, class FromNullType>
  intrusive_ptr& operator=(
      const intrusive_ptr<From, NullType>& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy assignment got pointer of wrong type.");
    intrusive_ptr tmp = rhs;
    swap(tmp);
    return *this;
  }

  // Non-owning access to the managed object (NullType::singleton() if empty).
  TTarget* get() const noexcept {
    return target_;
  }

  TTarget& operator*() const noexcept {
    return *target_;
  }

  TTarget* operator->() const noexcept {
    return target_;
  }

  operator bool() const noexcept {
    return target_ != NullType::singleton();
  }

  // Drop our reference (possibly destroying the target) and become empty.
  void reset() noexcept {
    reset_();
    target_ = NullType::singleton();
  }

  void swap(intrusive_ptr& rhs) noexcept {
    std::swap(target_, rhs.target_);
  }

  // We do a lot of null-pointer checks in our code, good to have this be cheap.
  bool defined() const noexcept {
    return target_ != NullType::singleton();
  }

  // Current number of strong references (0 for an empty pointer).
  uint32_t use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->refcount_.load(std::memory_order_acquire);
  }

  // Current weakcount, which includes the collective +1 contributed by strong
  // references (0 for an empty pointer).
  uint32_t weak_use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->weakcount_.load(std::memory_order_acquire);
  }

  bool unique() const noexcept {
    return use_count() == 1;
  }

  /**
   * Returns an owning (!) pointer to the underlying object and makes the
   * intrusive_ptr instance invalid. That means the refcount is not decreased.
   * You *must* put the returned pointer back into a intrusive_ptr using
   * intrusive_ptr::reclaim(ptr) to properly destruct it.
   * This is helpful for C APIs.
   */
  TTarget* release() noexcept {
    // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
    TTarget* result = target_;
    target_ = NullType::singleton();
    return result;
  }

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr that takes
   * over ownership. That means the refcount is not increased.
   * This is the counter-part to intrusive_ptr::release() and the pointer
   * passed in *must* have been created using intrusive_ptr::release().
   */
  static intrusive_ptr reclaim(TTarget* owning_ptr) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        owning_ptr == NullType::singleton() ||
            owning_ptr->refcount_.load() == 0 || owning_ptr->weakcount_.load(),
        "TTarget violates the invariant that refcount > 0 => weakcount > 0");
    return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{});
  }

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr
   * representing a new reference, i.e. the raw pointer retains
   * ownership.
   */
  static intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
    auto ret = reclaim(owning_ptr);
    ret.retain_();
    return ret;
  }

  /**
   * Allocate a heap object with args and wrap it inside a intrusive_ptr and
   * incref. This is a helper function to let make_intrusive() access private
   * intrusive_ptr constructors.
   */
  template <class... Args>
  static intrusive_ptr make(Args&&... args) {
    return intrusive_ptr(new TTarget(std::forward<Args>(args)...));
  }

  /**
   * Turn a new instance of TTarget (e.g., literally allocated
   * using new TTarget(...) into an intrusive_ptr. If possible,
   * use intrusive_ptr::make instead which statically guarantees
   * that the allocation was done properly.
   *
   * At the moment, the only reason this method exists is because
   * pybind11 holder types expect to be able to allocate in
   * this way (because pybind11 handles the new allocation itself).
   */
  static intrusive_ptr unsafe_steal_from_new(TTarget* raw_ptr) {
    return intrusive_ptr(raw_ptr);
  }

  /**
   * Turn an instance of TTarget that should not be reference counted
   * (e.g., allocated into an arena with placement new) into an
   * intrusive_ptr. This is gratuitously unsafe and should only be
   * used if you can guarantee that the pointer will not escape and be
   * refcounted as normal.
   *
   * `expected_decrefs` is a debugging parameter: it indicates the
   * number of strong owners the intrusive_ptr_target in question is
   * expected to get. In most use cases, this will likely be 1.
   *
   * The reason this method exists is for manually sharing
   * StorageImpls across Tensors in the static runtime. It needs
   * access to private intrusive_ptr members so that the refcounts can
   * be initialized to custom values.
   */
  static intrusive_ptr unsafe_adapt_non_heap_allocated(
      TTarget* raw_ptr,
      uint32_t expected_decrefs) {
    intrusive_ptr result(raw_ptr, raw::DontIncreaseRefcount{});
    // kImpracticallyHugeReferenceCount is impractically huge for a reference
    // count, while being in no danger of overflowing uint32_t. We actually only
    // need to initialize the refcount to 2 -- we are just doing an unbalanced
    // incref to prevent the non-heap-allocated target from being
    // freed, and we are optimizing that incref by directly
    // initializing the refcounts rather than doing an expensive
    // atomic increment. The reason to use kImpracticallyHugeReferenceCount is
    // to accommodate the debug assertions in ~intrusive_ptr_target.
#ifdef NDEBUG
    expected_decrefs = 0;
#endif
    result.target_->refcount_.store(
        detail::kImpracticallyHugeReferenceCount + expected_decrefs,
        std::memory_order_relaxed);
    result.target_->weakcount_.store(
        detail::kImpracticallyHugeReferenceCount, std::memory_order_relaxed);
    return result;
  }

  /**
   * Turn a **non-owning raw pointer** to an intrusive_ptr. It is
   * the moral equivalent of enable_shared_from_this on a shared pointer.
   *
   * This method is only valid for objects that are already live. If
   * you are looking for the moral equivalent of unique_ptr<T>(T*)
   * constructor, see steal_from_new.
   *
   * TODO: https://github.com/pytorch/pytorch/issues/56482
   */
  static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) {
    // See Note [Stack allocated intrusive_ptr_target safety]
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        raw_ptr == NullType::singleton() || raw_ptr->refcount_.load() > 0,
        "intrusive_ptr: Can only reclaim pointers that are owned by someone");
    auto ptr = reclaim(raw_ptr); // doesn't increase refcount
    ptr.retain_();
    return ptr;
  }
};
584
585 template <
586 class TTarget,
587 class NullType = detail::intrusive_target_default_null_type<TTarget>,
588 class... Args>
make_intrusive(Args &&...args)589 inline intrusive_ptr<TTarget, NullType> make_intrusive(Args&&... args) {
590 return intrusive_ptr<TTarget, NullType>::make(std::forward<Args>(args)...);
591 }
592
593 template <class TTarget, class NullType>
swap(intrusive_ptr<TTarget,NullType> & lhs,intrusive_ptr<TTarget,NullType> & rhs)594 inline void swap(
595 intrusive_ptr<TTarget, NullType>& lhs,
596 intrusive_ptr<TTarget, NullType>& rhs) noexcept {
597 lhs.swap(rhs);
598 }
599
600 // To allow intrusive_ptr inside std::map or std::set, we need operator<
601 template <class TTarget1, class NullType1, class TTarget2, class NullType2>
602 inline bool operator<(
603 const intrusive_ptr<TTarget1, NullType1>& lhs,
604 const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
605 return lhs.get() < rhs.get();
606 }
607
608 template <class TTarget1, class NullType1, class TTarget2, class NullType2>
609 inline bool operator==(
610 const intrusive_ptr<TTarget1, NullType1>& lhs,
611 const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
612 return lhs.get() == rhs.get();
613 }
614
615 template <class TTarget1, class NullType1>
616 inline bool operator==(
617 const intrusive_ptr<TTarget1, NullType1>& lhs,
618 std::nullptr_t) noexcept {
619 return lhs.get() == nullptr;
620 }
621
622 template <class TTarget2, class NullType2>
623 inline bool operator==(
624 std::nullptr_t,
625 const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
626 return nullptr == rhs.get();
627 }
628
629 template <class TTarget1, class NullType1, class TTarget2, class NullType2>
630 inline bool operator!=(
631 const intrusive_ptr<TTarget1, NullType1>& lhs,
632 const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
633 return !operator==(lhs, rhs);
634 }
635
636 template <class TTarget1, class NullType1>
637 inline bool operator!=(
638 const intrusive_ptr<TTarget1, NullType1>& lhs,
639 std::nullptr_t) noexcept {
640 return !operator==(lhs, nullptr);
641 }
642
643 template <class TTarget2, class NullType2>
644 inline bool operator!=(
645 std::nullptr_t,
646 const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
647 return !operator==(nullptr, rhs);
648 }
template <typename T>
struct MaybeOwnedTraits<c10::intrusive_ptr<T>> {
  using owned_type = c10::intrusive_ptr<T>;
  using borrow_type = c10::intrusive_ptr<T>;

  // A "borrow" here is an intrusive_ptr that aliases the owned pointer
  // WITHOUT holding a reference: it is built via reclaim() (no incref) and
  // defused via release() (no decref) before its destructor runs. The borrow
  // must therefore never outlive the owned_type it was created from.
  static borrow_type createBorrow(const owned_type& from) {
    return borrow_type::reclaim(from.get());
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    // Defuse the old borrow first so the assignment below doesn't decref
    // a pointer we never increfed.
    lhs.release();
    lhs = borrow_type::reclaim(rhs.get());
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    // "Leak" the pointer on purpose: the borrow never owned a reference.
    toDestroy.release();
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};
679
680 template <
681 typename TTarget,
682 class NullType = detail::intrusive_target_default_null_type<TTarget>>
683 class weak_intrusive_ptr final {
684 private:
685 static_assert(
686 std::is_base_of_v<intrusive_ptr_target, TTarget>,
687 "intrusive_ptr can only be used for classes that inherit from intrusive_ptr_target.");
688 #ifndef _WIN32
689 // This static_assert triggers on MSVC
690 // error C2131: expression did not evaluate to a constant
691 static_assert(
692 NullType::singleton() == NullType::singleton(),
693 "NullType must have a constexpr singleton() method");
694 #endif
695 static_assert(
696 std::is_base_of_v<
697 TTarget,
698 std::remove_pointer_t<decltype(NullType::singleton())>>,
699 "NullType::singleton() must return a element_type* pointer");
700
701 TTarget* target_;
702
703 template <class TTarget2, class NullType2>
704 friend class weak_intrusive_ptr;
705
706 void retain_() {
707 if (target_ != NullType::singleton()) {
708 uint32_t new_weakcount =
709 detail::atomic_weakcount_increment(target_->weakcount_);
710 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
711 new_weakcount != 1,
712 "weak_intrusive_ptr: Cannot increase weakcount after it reached zero.");
713 }
714 }
715
716 void reset_() noexcept {
717 if (target_ != NullType::singleton() &&
718 detail::atomic_weakcount_decrement(target_->weakcount_) == 0) {
719 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
720 delete target_;
721 }
722 target_ = NullType::singleton();
723 }
724
725 constexpr explicit weak_intrusive_ptr(TTarget* target) : target_(target) {}
726
727 public:
728 using element_type = TTarget;
729
730 explicit weak_intrusive_ptr(const intrusive_ptr<TTarget, NullType>& ptr)
731 : weak_intrusive_ptr(ptr.get()) {
732 retain_();
733 }
734
735 weak_intrusive_ptr(weak_intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
736 rhs.target_ = NullType::singleton();
737 }
738
739 template <class From, class FromNullType>
740 /* implicit */ weak_intrusive_ptr(
741 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
742 weak_intrusive_ptr<From, FromNullType>&& rhs) noexcept
743 : target_(
744 detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
745 static_assert(
746 std::is_convertible<From*, TTarget*>::value,
747 "Type mismatch. weak_intrusive_ptr move constructor got pointer of wrong type.");
748 rhs.target_ = FromNullType::singleton();
749 }
750
751 weak_intrusive_ptr(const weak_intrusive_ptr& rhs) : target_(rhs.target_) {
752 retain_();
753 }
754
755 template <class From, class FromNullType>
756 /* implicit */ weak_intrusive_ptr(
757 const weak_intrusive_ptr<From, FromNullType>& rhs)
758 : target_(
759 detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
760 static_assert(
761 std::is_convertible<From*, TTarget*>::value,
762 "Type mismatch. weak_intrusive_ptr copy constructor got pointer of wrong type.");
763 retain_();
764 }
765
766 ~weak_intrusive_ptr() noexcept {
767 reset_();
768 }
769
770 weak_intrusive_ptr& operator=(weak_intrusive_ptr&& rhs) & noexcept {
771 // NOLINTNEXTLINE(*assign*)
772 return this->template operator= <TTarget, NullType>(std::move(rhs));
773 }
774
775 template <class From, class FromNullType>
776 weak_intrusive_ptr& operator=(
777 weak_intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
778 static_assert(
779 std::is_convertible<From*, TTarget*>::value,
780 "Type mismatch. weak_intrusive_ptr move assignment got pointer of wrong type.");
781 weak_intrusive_ptr tmp = std::move(rhs);
782 swap(tmp);
783 return *this;
784 }
785
786 weak_intrusive_ptr& operator=(const weak_intrusive_ptr& rhs) & noexcept {
787 if (this == &rhs) {
788 return *this;
789 }
790 // NOLINTNEXTLINE(*assign*)
791 return this->template operator= <TTarget, NullType>(rhs);
792 }
793
794 weak_intrusive_ptr& operator=(
795 const intrusive_ptr<TTarget, NullType>& rhs) & noexcept {
796 weak_intrusive_ptr tmp(rhs);
797 swap(tmp);
798 return *this;
799 }
800
801 template <class From, class FromNullType>
802 weak_intrusive_ptr& operator=(
803 const weak_intrusive_ptr<From, NullType>& rhs) & noexcept {
804 static_assert(
805 std::is_convertible<From*, TTarget*>::value,
806 "Type mismatch. weak_intrusive_ptr copy assignment got pointer of wrong type.");
807 weak_intrusive_ptr tmp = rhs;
808 swap(tmp);
809 return *this;
810 }
811
812 void reset() noexcept {
813 reset_();
814 }
815
816 void swap(weak_intrusive_ptr& rhs) noexcept {
817 TTarget* tmp = target_;
818 target_ = rhs.target_;
819 rhs.target_ = tmp;
820 }
821
  // NB: This should ONLY be used by the std::hash implementation
  // for weak_intrusive_ptr. Another way you could do this is
  // friend std::hash<weak_intrusive_ptr>, but this triggers two
  // bugs:
  //
  // (1) It triggers an nvcc bug, where std::hash in a friend class
  //     declaration gets preprocessed into hash, which then cannot
  //     actually be found.  The error in this case looks like:
  //
  //        error: no template named 'hash'; did you mean 'std::hash'?
  //
  // (2) On OS X, std::hash is declared as a struct, not a class.
  //     This twinges:
  //
  //        error: class 'hash' was previously declared as a struct
  //        [-Werror,-Wmismatched-tags]
  //
  // Both of these are work-aroundable, but on the whole, I decided
  // it would be simpler and easier to make work if we just expose
  // an unsafe getter for target_
  //
  // Returns the raw target WITHOUT any refcount manipulation. The returned
  // pointer is non-owning and may point to an already-destructed object if
  // the strong refcount has reached zero — do not dereference it.
  TTarget* _unsafe_get_target() const noexcept {
    return target_;
  }
846
847 uint32_t use_count() const noexcept {
848 if (target_ == NullType::singleton()) {
849 return 0;
850 }
851 return target_->refcount_.load(
852 std::memory_order_acquire); // refcount, not weakcount!
853 }
854
855 uint32_t weak_use_count() const noexcept {
856 if (target_ == NullType::singleton()) {
857 return 0;
858 }
859 return target_->weakcount_.load(std::memory_order_acquire);
860 }
861
  // True if the referenced object has no strong references left (already
  // destructed) or if this weak pointer is null — use_count() returns 0 in
  // both cases.
  bool expired() const noexcept {
    return use_count() == 0;
  }
865
  /**
   * Tries to promote this weak reference to a strong intrusive_ptr.
   * Returns a null intrusive_ptr if the object has already been destructed.
   * The CAS loop increments the refcount only if it is observed to be
   * nonzero, so a dead object is never "resurrected".
   */
  intrusive_ptr<TTarget, NullType> lock() const noexcept {
    if (expired()) {
      return intrusive_ptr<TTarget, NullType>();
    } else {
      auto refcount = target_->refcount_.load(std::memory_order_seq_cst);
      do {
        if (refcount == 0) {
          // Object already destructed, no strong references left anymore.
          // Return nullptr.
          return intrusive_ptr<TTarget, NullType>();
        }
      } while (
          !target_->refcount_.compare_exchange_weak(refcount, refcount + 1));
      // We won the race and already bumped the refcount above, so construct
      // the strong pointer without incrementing again.
      return intrusive_ptr<TTarget, NullType>(
          target_, raw::DontIncreaseRefcount{});
    }
  }
883
884 /**
885 * Returns an owning (but still only weakly referenced) pointer to the
886 * underlying object and makes the weak_intrusive_ptr instance invalid.
887 * That means the weakcount is not decreased.
888 * You *must* put the returned pointer back into a weak_intrusive_ptr using
889 * weak_intrusive_ptr::reclaim(ptr) to properly destruct it.
890 * This is helpful for C APIs.
891 */
892 TTarget* release() noexcept {
893 TTarget* result = target_;
894 target_ = NullType::singleton();
895 return result;
896 }
897
898 /**
899 * Takes an owning (but must be weakly referenced) pointer to TTarget* and
900 * creates a weak_intrusive_ptr that takes over ownership.
901 * This means that the weakcount is not increased.
902 * This is the counter-part to weak_intrusive_ptr::release() and the pointer
903 * passed in *must* have been created using weak_intrusive_ptr::release().
904 */
905 static weak_intrusive_ptr reclaim(TTarget* owning_weak_ptr) {
906 // See Note [Stack allocated intrusive_ptr_target safety]
907 // if refcount > 0, weakcount must be >1 for weak references to exist.
908 // see weak counting explanation at top of this file.
909 // if refcount == 0, weakcount only must be >0.
910 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
911 owning_weak_ptr == NullType::singleton() ||
912 owning_weak_ptr->weakcount_.load() > 1 ||
913 (owning_weak_ptr->refcount_.load() == 0 &&
914 owning_weak_ptr->weakcount_.load() > 0),
915 "weak_intrusive_ptr: Can only weak_intrusive_ptr::reclaim() owning pointers that were created using weak_intrusive_ptr::release().");
916 return weak_intrusive_ptr(owning_weak_ptr);
917 }
918
919 /**
920 * Takes a pointer to TTarget* (may be weak or strong) and creates a
921 * new weak_intrusive_ptr representing a new weak reference, i.e.
922 * the raw pointer retains ownership.
923 */
924 static weak_intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
925 auto ret = reclaim(owning_ptr);
926 ret.retain_();
927 return ret;
928 }
929
  // The non-member comparison operators below compare the raw target_
  // pointers and therefore need access to the private member; grant them
  // friendship here.
  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
  friend bool operator<(
      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
  friend bool operator==(
      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
938 };
939
// Non-member swap so that generic code (ADL / `using std::swap; swap(a, b)`)
// picks up the cheap member-wise pointer swap.
template <class TTarget, class NullType>
inline void swap(
    weak_intrusive_ptr<TTarget, NullType>& lhs,
    weak_intrusive_ptr<TTarget, NullType>& rhs) noexcept {
  lhs.swap(rhs);
}
946
// To allow weak_intrusive_ptr inside std::map or std::set, we need operator<
// Ordering is by raw target pointer value, not by any property of the
// pointed-to object.
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator<(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ < rhs.target_;
}
954
// Equality is identity of the referenced object (raw pointer comparison),
// not deep equality of the pointed-to values.
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator==(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ == rhs.target_;
}
961
962 template <class TTarget1, class NullType1, class TTarget2, class NullType2>
963 inline bool operator!=(
964 const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
965 const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
966 return !operator==(lhs, rhs);
967 }
968
// Alias for documentary purposes, to more easily distinguish
// weak raw intrusive pointers from intrusive pointers. The two types are
// identical; the distinction exists purely to signal reader intent.
using weak_intrusive_ptr_target = intrusive_ptr_target;
972
973 // This namespace provides some methods for working with
974 // raw pointers that subclass intrusive_ptr_target. They are not provided
975 // as methods on intrusive_ptr_target, because ideally you would not need these
976 // methods at all (use smart pointers), but if you are dealing with legacy code
977 // that still needs to pass around raw pointers, you may find these quite
978 // useful.
979 //
980 // An important usage note: some functions are only valid if you have a
981 // strong raw pointer to the object, while others are only valid if you
982 // have a weak raw pointer to the object. ONLY call intrusive_ptr namespace
983 // functions on strong pointers, and weak_intrusive_ptr namespace functions
984 // on weak pointers. If you mix it up, you may get an assert failure.
985 namespace raw {
986
987 namespace intrusive_ptr {
988
989 // WARNING: Unlike the reclaim() API, it is NOT valid to pass
990 // NullType::singleton to this function
991 inline void incref(intrusive_ptr_target* self) {
992 if (self) {
993 detail::atomic_refcount_increment(self->refcount_);
994 }
995 }
996
// WARNING: Unlike the reclaim() API, it is NOT valid to pass
// NullType::singleton to this function
// Drops one strong reference: the temporary intrusive_ptr adopts the raw
// pointer and its destructor performs the decrement (and destruction, if this
// was the last strong reference).
inline void decref(intrusive_ptr_target* self) {
  // Let it die
  c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: Caller still has 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::intrusive_ptr class
}
1005
// Given a raw STRONG pointer, returns a raw WEAK pointer to the same object.
// The caller keeps its strong reference and additionally receives ownership
// of a new weak reference (to be disposed of via the weak_intrusive_ptr
// namespace functions or weak_intrusive_ptr::reclaim).
template <typename T>
inline T* make_weak(T* self) {
  // NB: 'this' is a strong pointer, but we return a weak pointer
  auto ptr = c10::intrusive_ptr<T>::reclaim(self); // temporarily adopt the strong ref
  c10::weak_intrusive_ptr<T> wptr(ptr); // create a new weak reference
  ptr.release(); // give the strong ref back to the caller
  return wptr.release(); // hand out the weak reference as a raw pointer
}
1014
1015 inline uint32_t use_count(intrusive_ptr_target* self) {
1016 auto ptr = c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
1017 auto r = ptr.use_count();
1018 ptr.release();
1019 return r;
1020 }
1021
1022 } // namespace intrusive_ptr
1023
1024 namespace weak_intrusive_ptr {
1025
// Bumps the weakcount of a raw weak pointer. Unlike the strong
// raw::intrusive_ptr::incref above, this does NOT accept nullptr.
inline void incref(weak_intrusive_ptr_target* self) {
  detail::atomic_weakcount_increment(self->weakcount_);
}
1029
// Drops one weak reference: the temporary weak_intrusive_ptr adopts the raw
// pointer and its destructor performs the weakcount decrement.
inline void decref(weak_intrusive_ptr_target* self) {
  // Let it die
  c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: You still "have" the 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::weak_intrusive_ptr class
}
1036
// Given a raw WEAK pointer, tries to promote it to a raw STRONG pointer.
// The caller keeps its weak reference; on success the caller additionally
// owns a new strong reference. Returns the null sentinel if the object has
// already expired.
template <typename T>
inline T* lock(T* self) {
  auto wptr = c10::weak_intrusive_ptr<T>::reclaim(self); // temporarily adopt the weak ref
  auto ptr = wptr.lock(); // attempt the weak -> strong promotion
  wptr.release(); // give the weak ref back to the caller
  return ptr.release(); // hand out the (possibly null) strong reference
}
1044
// This gives the STRONG refcount of a WEAK pointer
// (i.e. how many strong references to the object still exist). The caller's
// weak reference is temporarily adopted, queried, and given back unchanged.
inline uint32_t use_count(weak_intrusive_ptr_target* self) {
  auto wptr = c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  auto r = wptr.use_count();
  wptr.release();
  return r;
}
1052
1053 } // namespace weak_intrusive_ptr
1054
1055 } // namespace raw
1056
1057 } // namespace c10
1058
1059 namespace std {
1060 // To allow intrusive_ptr and weak_intrusive_ptr inside std::unordered_map or
1061 // std::unordered_set, we need std::hash
1062 template <class TTarget, class NullType>
1063 struct hash<c10::intrusive_ptr<TTarget, NullType>> {
1064 size_t operator()(const c10::intrusive_ptr<TTarget, NullType>& x) const {
1065 return std::hash<TTarget*>()(x.get());
1066 }
1067 };
1068 template <class TTarget, class NullType>
1069 struct hash<c10::weak_intrusive_ptr<TTarget, NullType>> {
1070 size_t operator()(const c10::weak_intrusive_ptr<TTarget, NullType>& x) const {
1071 return std::hash<TTarget*>()(x._unsafe_get_target());
1072 }
1073 };
1074 } // namespace std
1075