xref: /aosp_15_r20/external/pytorch/c10/util/MaybeOwned.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <memory>
#include <new>
#include <type_traits>
#include <utility>
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker namespace c10 {
11*da0073e9SAndroid Build Coastguard Worker 
/// MaybeOwnedTraitsGenericImpl<T> is the fallback recipe for borrowing
/// from an arbitrary type T: a borrow is just a pointer-to-const aimed at
/// the borrowed-from object. A MaybeOwnedTraits specialization can inherit
/// from this struct to obtain that behavior.
template <typename T>
struct MaybeOwnedTraitsGenericImpl {
  using owned_type = T;
  using borrow_type = const T*;

  /// Borrow `from` by recording its address.
  static auto createBorrow(const owned_type& from) -> borrow_type {
    return &from;
  }

  /// Re-point `lhs` at whatever `rhs` borrows.
  static auto assignBorrow(borrow_type& lhs, borrow_type rhs) -> void {
    lhs = rhs;
  }

  /// A raw pointer owns nothing, so there is nothing to release.
  static auto destroyBorrow(borrow_type& /*toDestroy*/) -> void {}

  /// Dereference the borrow to reach the borrowed-from object.
  static auto referenceFromBorrow(const borrow_type& borrow)
      -> const owned_type& {
    return *borrow;
  }

  /// The borrow already is a pointer to the borrowed-from object.
  static auto pointerFromBorrow(const borrow_type& borrow)
      -> const owned_type* {
    return borrow;
  }

  /// Debug-only sanity check: a usable borrow is a non-null pointer.
  static auto debugBorrowIsValid(const borrow_type& borrow) -> bool {
    return nullptr != borrow;
  }
};
42*da0073e9SAndroid Build Coastguard Worker 
/// For some types that we control, the extra layer of indirection in a
/// borrow can be eliminated. For example, see intrusive_ptr.h and
/// TensorBody.h.
///
/// The primary template is declared but deliberately never defined:
/// only types given an explicit MaybeOwnedTraits specialization can be
/// used with MaybeOwned.
template <class T>
struct MaybeOwnedTraits;
49*da0073e9SAndroid Build Coastguard Worker 
50*da0073e9SAndroid Build Coastguard Worker // Explicitly enable MaybeOwned<shared_ptr<T>>, rather than allowing
51*da0073e9SAndroid Build Coastguard Worker // MaybeOwned to be used for any type right away.
52*da0073e9SAndroid Build Coastguard Worker template <typename T>
53*da0073e9SAndroid Build Coastguard Worker struct MaybeOwnedTraits<std::shared_ptr<T>>
54*da0073e9SAndroid Build Coastguard Worker     : public MaybeOwnedTraitsGenericImpl<std::shared_ptr<T>> {};
55*da0073e9SAndroid Build Coastguard Worker 
56*da0073e9SAndroid Build Coastguard Worker /// A smart pointer around either a borrowed or owned T. When
57*da0073e9SAndroid Build Coastguard Worker /// constructed with borrowed(), the caller MUST ensure that the
58*da0073e9SAndroid Build Coastguard Worker /// borrowed-from argument outlives this MaybeOwned<T>. Compare to
59*da0073e9SAndroid Build Coastguard Worker /// Rust's std::borrow::Cow
60*da0073e9SAndroid Build Coastguard Worker /// (https://doc.rust-lang.org/std/borrow/enum.Cow.html), but note
61*da0073e9SAndroid Build Coastguard Worker /// that it is probably not suitable for general use because C++ has
62*da0073e9SAndroid Build Coastguard Worker /// no borrow checking. Included here to support
63*da0073e9SAndroid Build Coastguard Worker /// Tensor::expect_contiguous.
64*da0073e9SAndroid Build Coastguard Worker template <typename T>
65*da0073e9SAndroid Build Coastguard Worker class MaybeOwned final {
66*da0073e9SAndroid Build Coastguard Worker   using borrow_type = typename MaybeOwnedTraits<T>::borrow_type;
67*da0073e9SAndroid Build Coastguard Worker   using owned_type = typename MaybeOwnedTraits<T>::owned_type;
68*da0073e9SAndroid Build Coastguard Worker 
69*da0073e9SAndroid Build Coastguard Worker   bool isBorrowed_;
70*da0073e9SAndroid Build Coastguard Worker   union {
71*da0073e9SAndroid Build Coastguard Worker     borrow_type borrow_;
72*da0073e9SAndroid Build Coastguard Worker     owned_type own_;
73*da0073e9SAndroid Build Coastguard Worker   };
74*da0073e9SAndroid Build Coastguard Worker 
75*da0073e9SAndroid Build Coastguard Worker   /// Don't use this; use borrowed() instead.
76*da0073e9SAndroid Build Coastguard Worker   explicit MaybeOwned(const owned_type& t)
77*da0073e9SAndroid Build Coastguard Worker       : isBorrowed_(true), borrow_(MaybeOwnedTraits<T>::createBorrow(t)) {}
78*da0073e9SAndroid Build Coastguard Worker 
79*da0073e9SAndroid Build Coastguard Worker   /// Don't use this; use owned() instead.
80*da0073e9SAndroid Build Coastguard Worker   explicit MaybeOwned(T&& t) noexcept(std::is_nothrow_move_constructible_v<T>)
81*da0073e9SAndroid Build Coastguard Worker       : isBorrowed_(false), own_(std::move(t)) {}
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker   /// Don't use this; use owned() instead.
84*da0073e9SAndroid Build Coastguard Worker   template <class... Args>
85*da0073e9SAndroid Build Coastguard Worker   explicit MaybeOwned(std::in_place_t, Args&&... args)
86*da0073e9SAndroid Build Coastguard Worker       : isBorrowed_(false), own_(std::forward<Args>(args)...) {}
87*da0073e9SAndroid Build Coastguard Worker 
88*da0073e9SAndroid Build Coastguard Worker  public:
89*da0073e9SAndroid Build Coastguard Worker   explicit MaybeOwned() : isBorrowed_(true), borrow_() {}
90*da0073e9SAndroid Build Coastguard Worker 
91*da0073e9SAndroid Build Coastguard Worker   // Copying a borrow yields another borrow of the original, as with a
92*da0073e9SAndroid Build Coastguard Worker   // T*. Copying an owned T yields another owned T for safety: no
93*da0073e9SAndroid Build Coastguard Worker   // chains of borrowing by default! (Note you could get that behavior
94*da0073e9SAndroid Build Coastguard Worker   // with MaybeOwned<T>::borrowed(*rhs) if you wanted it.)
95*da0073e9SAndroid Build Coastguard Worker   MaybeOwned(const MaybeOwned& rhs) : isBorrowed_(rhs.isBorrowed_) {
96*da0073e9SAndroid Build Coastguard Worker     if (C10_LIKELY(rhs.isBorrowed_)) {
97*da0073e9SAndroid Build Coastguard Worker       MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
98*da0073e9SAndroid Build Coastguard Worker     } else {
99*da0073e9SAndroid Build Coastguard Worker       new (&own_) T(rhs.own_);
100*da0073e9SAndroid Build Coastguard Worker     }
101*da0073e9SAndroid Build Coastguard Worker   }
102*da0073e9SAndroid Build Coastguard Worker 
103*da0073e9SAndroid Build Coastguard Worker   MaybeOwned& operator=(const MaybeOwned& rhs) {
104*da0073e9SAndroid Build Coastguard Worker     if (this == &rhs) {
105*da0073e9SAndroid Build Coastguard Worker       return *this;
106*da0073e9SAndroid Build Coastguard Worker     }
107*da0073e9SAndroid Build Coastguard Worker     if (C10_UNLIKELY(!isBorrowed_)) {
108*da0073e9SAndroid Build Coastguard Worker       if (rhs.isBorrowed_) {
109*da0073e9SAndroid Build Coastguard Worker         own_.~T();
110*da0073e9SAndroid Build Coastguard Worker         MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
111*da0073e9SAndroid Build Coastguard Worker         isBorrowed_ = true;
112*da0073e9SAndroid Build Coastguard Worker       } else {
113*da0073e9SAndroid Build Coastguard Worker         own_ = rhs.own_;
114*da0073e9SAndroid Build Coastguard Worker       }
115*da0073e9SAndroid Build Coastguard Worker     } else {
116*da0073e9SAndroid Build Coastguard Worker       if (C10_LIKELY(rhs.isBorrowed_)) {
117*da0073e9SAndroid Build Coastguard Worker         MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
118*da0073e9SAndroid Build Coastguard Worker       } else {
119*da0073e9SAndroid Build Coastguard Worker         MaybeOwnedTraits<T>::destroyBorrow(borrow_);
120*da0073e9SAndroid Build Coastguard Worker         new (&own_) T(rhs.own_);
121*da0073e9SAndroid Build Coastguard Worker         isBorrowed_ = false;
122*da0073e9SAndroid Build Coastguard Worker       }
123*da0073e9SAndroid Build Coastguard Worker     }
124*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isBorrowed_ == rhs.isBorrowed_);
125*da0073e9SAndroid Build Coastguard Worker     return *this;
126*da0073e9SAndroid Build Coastguard Worker   }
127*da0073e9SAndroid Build Coastguard Worker 
128*da0073e9SAndroid Build Coastguard Worker   MaybeOwned(MaybeOwned&& rhs) noexcept(
129*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(*-noexcept-move-*)
130*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_move_constructible_v<T> &&
131*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_move_assignable_v<borrow_type>)
132*da0073e9SAndroid Build Coastguard Worker       : isBorrowed_(rhs.isBorrowed_) {
133*da0073e9SAndroid Build Coastguard Worker     if (C10_LIKELY(rhs.isBorrowed_)) {
134*da0073e9SAndroid Build Coastguard Worker       MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
135*da0073e9SAndroid Build Coastguard Worker     } else {
136*da0073e9SAndroid Build Coastguard Worker       new (&own_) T(std::move(rhs.own_));
137*da0073e9SAndroid Build Coastguard Worker     }
138*da0073e9SAndroid Build Coastguard Worker   }
139*da0073e9SAndroid Build Coastguard Worker 
140*da0073e9SAndroid Build Coastguard Worker   MaybeOwned& operator=(MaybeOwned&& rhs) noexcept(
141*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_move_assignable_v<T> &&
142*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_move_assignable_v<borrow_type> &&
143*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_move_constructible_v<T> &&
144*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(*-noexcept-move-*)
145*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_destructible_v<T> &&
146*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_destructible_v<borrow_type>) {
147*da0073e9SAndroid Build Coastguard Worker     if (this == &rhs) {
148*da0073e9SAndroid Build Coastguard Worker       return *this;
149*da0073e9SAndroid Build Coastguard Worker     }
150*da0073e9SAndroid Build Coastguard Worker     if (C10_UNLIKELY(!isBorrowed_)) {
151*da0073e9SAndroid Build Coastguard Worker       if (rhs.isBorrowed_) {
152*da0073e9SAndroid Build Coastguard Worker         own_.~T();
153*da0073e9SAndroid Build Coastguard Worker         MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
154*da0073e9SAndroid Build Coastguard Worker         isBorrowed_ = true;
155*da0073e9SAndroid Build Coastguard Worker       } else {
156*da0073e9SAndroid Build Coastguard Worker         own_ = std::move(rhs.own_);
157*da0073e9SAndroid Build Coastguard Worker       }
158*da0073e9SAndroid Build Coastguard Worker     } else {
159*da0073e9SAndroid Build Coastguard Worker       if (C10_LIKELY(rhs.isBorrowed_)) {
160*da0073e9SAndroid Build Coastguard Worker         MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
161*da0073e9SAndroid Build Coastguard Worker       } else {
162*da0073e9SAndroid Build Coastguard Worker         MaybeOwnedTraits<T>::destroyBorrow(borrow_);
163*da0073e9SAndroid Build Coastguard Worker         new (&own_) T(std::move(rhs.own_));
164*da0073e9SAndroid Build Coastguard Worker         isBorrowed_ = false;
165*da0073e9SAndroid Build Coastguard Worker       }
166*da0073e9SAndroid Build Coastguard Worker     }
167*da0073e9SAndroid Build Coastguard Worker     return *this;
168*da0073e9SAndroid Build Coastguard Worker   }
169*da0073e9SAndroid Build Coastguard Worker 
170*da0073e9SAndroid Build Coastguard Worker   static MaybeOwned borrowed(const T& t) {
171*da0073e9SAndroid Build Coastguard Worker     return MaybeOwned(t);
172*da0073e9SAndroid Build Coastguard Worker   }
173*da0073e9SAndroid Build Coastguard Worker 
174*da0073e9SAndroid Build Coastguard Worker   static MaybeOwned owned(T&& t) noexcept(
175*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_move_constructible_v<T>) {
176*da0073e9SAndroid Build Coastguard Worker     return MaybeOwned(std::move(t));
177*da0073e9SAndroid Build Coastguard Worker   }
178*da0073e9SAndroid Build Coastguard Worker 
179*da0073e9SAndroid Build Coastguard Worker   template <class... Args>
180*da0073e9SAndroid Build Coastguard Worker   static MaybeOwned owned(std::in_place_t, Args&&... args) {
181*da0073e9SAndroid Build Coastguard Worker     return MaybeOwned(std::in_place, std::forward<Args>(args)...);
182*da0073e9SAndroid Build Coastguard Worker   }
183*da0073e9SAndroid Build Coastguard Worker 
184*da0073e9SAndroid Build Coastguard Worker   ~MaybeOwned() noexcept(
185*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(*-noexcept-destructor)
186*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_destructible_v<T> &&
187*da0073e9SAndroid Build Coastguard Worker       std::is_nothrow_destructible_v<borrow_type>) {
188*da0073e9SAndroid Build Coastguard Worker     if (C10_UNLIKELY(!isBorrowed_)) {
189*da0073e9SAndroid Build Coastguard Worker       own_.~T();
190*da0073e9SAndroid Build Coastguard Worker     } else {
191*da0073e9SAndroid Build Coastguard Worker       MaybeOwnedTraits<T>::destroyBorrow(borrow_);
192*da0073e9SAndroid Build Coastguard Worker     }
193*da0073e9SAndroid Build Coastguard Worker   }
194*da0073e9SAndroid Build Coastguard Worker 
195*da0073e9SAndroid Build Coastguard Worker   // This is an implementation detail!  You should know what you're doing
196*da0073e9SAndroid Build Coastguard Worker   // if you are testing this.  If you just want to guarantee ownership move
197*da0073e9SAndroid Build Coastguard Worker   // this into a T
198*da0073e9SAndroid Build Coastguard Worker   bool unsafeIsBorrowed() const {
199*da0073e9SAndroid Build Coastguard Worker     return isBorrowed_;
200*da0073e9SAndroid Build Coastguard Worker   }
201*da0073e9SAndroid Build Coastguard Worker 
202*da0073e9SAndroid Build Coastguard Worker   const T& operator*() const& {
203*da0073e9SAndroid Build Coastguard Worker     if (isBorrowed_) {
204*da0073e9SAndroid Build Coastguard Worker       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
205*da0073e9SAndroid Build Coastguard Worker           MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
206*da0073e9SAndroid Build Coastguard Worker     }
207*da0073e9SAndroid Build Coastguard Worker     return C10_LIKELY(isBorrowed_)
208*da0073e9SAndroid Build Coastguard Worker         ? MaybeOwnedTraits<T>::referenceFromBorrow(borrow_)
209*da0073e9SAndroid Build Coastguard Worker         : own_;
210*da0073e9SAndroid Build Coastguard Worker   }
211*da0073e9SAndroid Build Coastguard Worker 
212*da0073e9SAndroid Build Coastguard Worker   const T* operator->() const {
213*da0073e9SAndroid Build Coastguard Worker     if (isBorrowed_) {
214*da0073e9SAndroid Build Coastguard Worker       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
215*da0073e9SAndroid Build Coastguard Worker           MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
216*da0073e9SAndroid Build Coastguard Worker     }
217*da0073e9SAndroid Build Coastguard Worker     return C10_LIKELY(isBorrowed_)
218*da0073e9SAndroid Build Coastguard Worker         ? MaybeOwnedTraits<T>::pointerFromBorrow(borrow_)
219*da0073e9SAndroid Build Coastguard Worker         : &own_;
220*da0073e9SAndroid Build Coastguard Worker   }
221*da0073e9SAndroid Build Coastguard Worker 
222*da0073e9SAndroid Build Coastguard Worker   // If borrowed, copy the underlying T. If owned, move from
223*da0073e9SAndroid Build Coastguard Worker   // it. borrowed/owned state remains the same, and either we
224*da0073e9SAndroid Build Coastguard Worker   // reference the same borrow as before or we are an owned moved-from
225*da0073e9SAndroid Build Coastguard Worker   // T.
226*da0073e9SAndroid Build Coastguard Worker   T operator*() && {
227*da0073e9SAndroid Build Coastguard Worker     if (isBorrowed_) {
228*da0073e9SAndroid Build Coastguard Worker       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
229*da0073e9SAndroid Build Coastguard Worker           MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
230*da0073e9SAndroid Build Coastguard Worker       return MaybeOwnedTraits<T>::referenceFromBorrow(borrow_);
231*da0073e9SAndroid Build Coastguard Worker     } else {
232*da0073e9SAndroid Build Coastguard Worker       return std::move(own_);
233*da0073e9SAndroid Build Coastguard Worker     }
234*da0073e9SAndroid Build Coastguard Worker   }
235*da0073e9SAndroid Build Coastguard Worker };
236*da0073e9SAndroid Build Coastguard Worker 
237*da0073e9SAndroid Build Coastguard Worker } // namespace c10
238