// xref: /aosp_15_r20/external/pytorch/c10/util/MaybeOwned.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <memory>
#include <new>
#include <type_traits>
#include <utility>
9 
10 namespace c10 {
11 
/// MaybeOwnedTraits<T> describes how to borrow from T. Here is how we
/// can implement borrowing from an arbitrary type T using a raw
/// pointer to const: a borrow is simply a `const T*`, so creating,
/// copying and destroying a borrow never touches the referenced T.
template <typename T>
struct MaybeOwnedTraitsGenericImpl {
  /// Type stored by a MaybeOwned that owns its value.
  using owned_type = T;
  /// Type stored by a MaybeOwned that borrows its value.
  using borrow_type = const T*;

  /// Creates a borrow of `from`. `from` must outlive the returned borrow.
  static borrow_type createBorrow(const owned_type& from) {
    return &from;
  }

  /// Makes `lhs` refer to the same object as `rhs`.
  static void assignBorrow(borrow_type& lhs, borrow_type rhs) {
    lhs = rhs;
  }

  /// Nothing to release: a raw-pointer borrow owns nothing.
  static void destroyBorrow(borrow_type& /*toDestroy*/) {}

  /// Returns the object the borrow refers to. The borrow must be non-null.
  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return *borrow;
  }

  /// Returns a pointer to the object the borrow refers to.
  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  /// A borrow is considered valid iff it is non-null. Used by
  /// MaybeOwned's debug-only assertions.
  static bool debugBorrowIsValid(const borrow_type& borrow) {
    return borrow != nullptr;
  }
};
42 
43 /// It is possible to eliminate the extra layer of indirection for
44 /// borrows for some types that we control. For examples, see
45 /// intrusive_ptr.h and TensorBody.h.
46 
47 template <typename T>
48 struct MaybeOwnedTraits;
49 
50 // Explicitly enable MaybeOwned<shared_ptr<T>>, rather than allowing
51 // MaybeOwned to be used for any type right away.
52 template <typename T>
53 struct MaybeOwnedTraits<std::shared_ptr<T>>
54     : public MaybeOwnedTraitsGenericImpl<std::shared_ptr<T>> {};
55 
56 /// A smart pointer around either a borrowed or owned T. When
57 /// constructed with borrowed(), the caller MUST ensure that the
58 /// borrowed-from argument outlives this MaybeOwned<T>. Compare to
59 /// Rust's std::borrow::Cow
60 /// (https://doc.rust-lang.org/std/borrow/enum.Cow.html), but note
61 /// that it is probably not suitable for general use because C++ has
62 /// no borrow checking. Included here to support
63 /// Tensor::expect_contiguous.
64 template <typename T>
65 class MaybeOwned final {
66   using borrow_type = typename MaybeOwnedTraits<T>::borrow_type;
67   using owned_type = typename MaybeOwnedTraits<T>::owned_type;
68 
69   bool isBorrowed_;
70   union {
71     borrow_type borrow_;
72     owned_type own_;
73   };
74 
75   /// Don't use this; use borrowed() instead.
76   explicit MaybeOwned(const owned_type& t)
77       : isBorrowed_(true), borrow_(MaybeOwnedTraits<T>::createBorrow(t)) {}
78 
79   /// Don't use this; use owned() instead.
80   explicit MaybeOwned(T&& t) noexcept(std::is_nothrow_move_constructible_v<T>)
81       : isBorrowed_(false), own_(std::move(t)) {}
82 
83   /// Don't use this; use owned() instead.
84   template <class... Args>
85   explicit MaybeOwned(std::in_place_t, Args&&... args)
86       : isBorrowed_(false), own_(std::forward<Args>(args)...) {}
87 
88  public:
89   explicit MaybeOwned() : isBorrowed_(true), borrow_() {}
90 
91   // Copying a borrow yields another borrow of the original, as with a
92   // T*. Copying an owned T yields another owned T for safety: no
93   // chains of borrowing by default! (Note you could get that behavior
94   // with MaybeOwned<T>::borrowed(*rhs) if you wanted it.)
95   MaybeOwned(const MaybeOwned& rhs) : isBorrowed_(rhs.isBorrowed_) {
96     if (C10_LIKELY(rhs.isBorrowed_)) {
97       MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
98     } else {
99       new (&own_) T(rhs.own_);
100     }
101   }
102 
103   MaybeOwned& operator=(const MaybeOwned& rhs) {
104     if (this == &rhs) {
105       return *this;
106     }
107     if (C10_UNLIKELY(!isBorrowed_)) {
108       if (rhs.isBorrowed_) {
109         own_.~T();
110         MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
111         isBorrowed_ = true;
112       } else {
113         own_ = rhs.own_;
114       }
115     } else {
116       if (C10_LIKELY(rhs.isBorrowed_)) {
117         MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
118       } else {
119         MaybeOwnedTraits<T>::destroyBorrow(borrow_);
120         new (&own_) T(rhs.own_);
121         isBorrowed_ = false;
122       }
123     }
124     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isBorrowed_ == rhs.isBorrowed_);
125     return *this;
126   }
127 
128   MaybeOwned(MaybeOwned&& rhs) noexcept(
129       // NOLINTNEXTLINE(*-noexcept-move-*)
130       std::is_nothrow_move_constructible_v<T> &&
131       std::is_nothrow_move_assignable_v<borrow_type>)
132       : isBorrowed_(rhs.isBorrowed_) {
133     if (C10_LIKELY(rhs.isBorrowed_)) {
134       MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
135     } else {
136       new (&own_) T(std::move(rhs.own_));
137     }
138   }
139 
140   MaybeOwned& operator=(MaybeOwned&& rhs) noexcept(
141       std::is_nothrow_move_assignable_v<T> &&
142       std::is_nothrow_move_assignable_v<borrow_type> &&
143       std::is_nothrow_move_constructible_v<T> &&
144       // NOLINTNEXTLINE(*-noexcept-move-*)
145       std::is_nothrow_destructible_v<T> &&
146       std::is_nothrow_destructible_v<borrow_type>) {
147     if (this == &rhs) {
148       return *this;
149     }
150     if (C10_UNLIKELY(!isBorrowed_)) {
151       if (rhs.isBorrowed_) {
152         own_.~T();
153         MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
154         isBorrowed_ = true;
155       } else {
156         own_ = std::move(rhs.own_);
157       }
158     } else {
159       if (C10_LIKELY(rhs.isBorrowed_)) {
160         MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
161       } else {
162         MaybeOwnedTraits<T>::destroyBorrow(borrow_);
163         new (&own_) T(std::move(rhs.own_));
164         isBorrowed_ = false;
165       }
166     }
167     return *this;
168   }
169 
170   static MaybeOwned borrowed(const T& t) {
171     return MaybeOwned(t);
172   }
173 
174   static MaybeOwned owned(T&& t) noexcept(
175       std::is_nothrow_move_constructible_v<T>) {
176     return MaybeOwned(std::move(t));
177   }
178 
179   template <class... Args>
180   static MaybeOwned owned(std::in_place_t, Args&&... args) {
181     return MaybeOwned(std::in_place, std::forward<Args>(args)...);
182   }
183 
184   ~MaybeOwned() noexcept(
185       // NOLINTNEXTLINE(*-noexcept-destructor)
186       std::is_nothrow_destructible_v<T> &&
187       std::is_nothrow_destructible_v<borrow_type>) {
188     if (C10_UNLIKELY(!isBorrowed_)) {
189       own_.~T();
190     } else {
191       MaybeOwnedTraits<T>::destroyBorrow(borrow_);
192     }
193   }
194 
195   // This is an implementation detail!  You should know what you're doing
196   // if you are testing this.  If you just want to guarantee ownership move
197   // this into a T
198   bool unsafeIsBorrowed() const {
199     return isBorrowed_;
200   }
201 
202   const T& operator*() const& {
203     if (isBorrowed_) {
204       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
205           MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
206     }
207     return C10_LIKELY(isBorrowed_)
208         ? MaybeOwnedTraits<T>::referenceFromBorrow(borrow_)
209         : own_;
210   }
211 
212   const T* operator->() const {
213     if (isBorrowed_) {
214       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
215           MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
216     }
217     return C10_LIKELY(isBorrowed_)
218         ? MaybeOwnedTraits<T>::pointerFromBorrow(borrow_)
219         : &own_;
220   }
221 
222   // If borrowed, copy the underlying T. If owned, move from
223   // it. borrowed/owned state remains the same, and either we
224   // reference the same borrow as before or we are an owned moved-from
225   // T.
226   T operator*() && {
227     if (isBorrowed_) {
228       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
229           MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
230       return MaybeOwnedTraits<T>::referenceFromBorrow(borrow_);
231     } else {
232       return std::move(own_);
233     }
234   }
235 };
236 
237 } // namespace c10
238