xref: /aosp_15_r20/external/pytorch/c10/core/Storage.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Allocator.h>
4 #include <c10/core/Device.h>
5 #include <c10/core/DeviceType.h>
6 #include <c10/core/StorageImpl.h>
7 #include <c10/core/SymInt.h>
8 #include <c10/macros/Export.h>
9 #include <c10/util/Exception.h>
10 #include <c10/util/ExclusivelyOwned.h>
11 #include <c10/util/MaybeOwned.h>
12 #include <c10/util/UniqueVoidPtr.h>
13 #include <c10/util/intrusive_ptr.h>
14 #include <cstddef>
15 #include <utility>
16 
17 namespace c10 {
18 
19 struct Storage;
20 
21 C10_API bool isSharedStorageAlias(
22     const Storage& storage0,
23     const Storage& storage1);
24 
25 struct C10_API Storage {
26  public:
27   struct use_byte_size_t {};
28   struct unsafe_borrow_t {
29     explicit unsafe_borrow_t() = default;
30   };
31 
32   Storage() = default;
StorageStorage33   Storage(c10::intrusive_ptr<StorageImpl> ptr)
34       : storage_impl_(std::move(ptr)) {}
35 
36   // Allocates memory buffer using given allocator and creates a storage with it
37   Storage(
38       use_byte_size_t /*use_byte_size*/,
39       const SymInt& size_bytes,
40       Allocator* allocator = nullptr,
41       bool resizable = false)
storage_impl_Storage42       : storage_impl_(c10::make_intrusive<StorageImpl>(
43             StorageImpl::use_byte_size_t(),
44             size_bytes,
45             allocator,
46             resizable)) {}
47 
48   // Creates storage with pre-allocated memory buffer. Allocator is given for
49   // potential future reallocations, however it can be nullptr if the storage
50   // is non-resizable
51   Storage(
52       use_byte_size_t /*use_byte_size*/,
53       size_t size_bytes,
54       at::DataPtr data_ptr,
55       at::Allocator* allocator = nullptr,
56       bool resizable = false)
storage_impl_Storage57       : storage_impl_(c10::make_intrusive<StorageImpl>(
58             StorageImpl::use_byte_size_t(),
59             size_bytes,
60             std::move(data_ptr),
61             allocator,
62             resizable)) {}
63 
64  protected:
StorageStorage65   explicit Storage(unsafe_borrow_t, const Storage& rhs)
66       : storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim(
67             rhs.storage_impl_.get())) {}
68 
69   friend MaybeOwnedTraits<Storage>;
70 
71  public:
72   // Legacy constructor for partially initialized (dtype or memory) storages
73   // that can be temporarily created with Caffe2 APIs. See the note on top of
74   // TensorImpl.h for details.
create_legacyStorage75   static Storage create_legacy(at::Device device) {
76     auto allocator = GetAllocator(device.type());
77     return Storage(c10::make_intrusive<StorageImpl>(
78         StorageImpl::use_byte_size_t(),
79         0,
80         allocator->allocate(0), // materialize a non-default Device.
81         allocator,
82         true));
83   }
84 
85   // Mimic create_legacy, but without requiring a newly-created StorageImpl.
reset_legacyStorage86   void reset_legacy() {
87     TORCH_CHECK(resizable() && allocator());
88     set_nbytes(0);
89     set_data_ptr_noswap(allocator()->allocate(0));
90   }
91 
92   // TODO: remove later
set_nbytesStorage93   void set_nbytes(size_t size_bytes) const {
94     storage_impl_->set_nbytes(size_bytes);
95   }
96 
set_nbytesStorage97   void set_nbytes(c10::SymInt size_bytes) const {
98     storage_impl_->set_nbytes(std::move(size_bytes));
99   }
100 
resizableStorage101   bool resizable() const {
102     return storage_impl_->resizable();
103   }
104 
nbytesStorage105   size_t nbytes() const {
106     return storage_impl_->nbytes();
107   }
108 
sym_nbytesStorage109   SymInt sym_nbytes() const {
110     return storage_impl_->sym_nbytes();
111   }
112   // get() use here is to get const-correctness
113 
dataStorage114   const void* data() const {
115     return storage_impl_->data();
116   }
117 
mutable_dataStorage118   void* mutable_data() const {
119     return storage_impl_->mutable_data();
120   }
121 
mutable_data_ptrStorage122   at::DataPtr& mutable_data_ptr() const {
123     return storage_impl_->mutable_data_ptr();
124   }
125 
data_ptrStorage126   const at::DataPtr& data_ptr() const {
127     return storage_impl_->data_ptr();
128   }
129 
130   // Returns the previous data_ptr
set_data_ptrStorage131   at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) const {
132     return storage_impl_->set_data_ptr(std::move(data_ptr));
133   }
134 
set_data_ptr_noswapStorage135   void set_data_ptr_noswap(at::DataPtr&& data_ptr) const {
136     return storage_impl_->set_data_ptr_noswap(std::move(data_ptr));
137   }
138 
device_typeStorage139   DeviceType device_type() const {
140     return storage_impl_->device_type();
141   }
142 
allocatorStorage143   at::Allocator* allocator() const {
144     return storage_impl_->allocator();
145   }
146 
deviceStorage147   at::Device device() const {
148     return storage_impl_->device();
149   }
150 
unsafeReleaseStorageImplStorage151   StorageImpl* unsafeReleaseStorageImpl() {
152     return storage_impl_.release();
153   }
154 
unsafeGetStorageImplStorage155   StorageImpl* unsafeGetStorageImpl() const noexcept {
156     return storage_impl_.get();
157   }
158 
getWeakStorageImplStorage159   c10::weak_intrusive_ptr<StorageImpl> getWeakStorageImpl() const {
160     return c10::weak_intrusive_ptr<StorageImpl>(storage_impl_);
161   }
162 
163   operator bool() const {
164     return storage_impl_;
165   }
166 
use_countStorage167   size_t use_count() const {
168     return storage_impl_.use_count();
169   }
170 
uniqueStorage171   inline bool unique() const {
172     return storage_impl_.unique();
173   }
174 
is_alias_ofStorage175   bool is_alias_of(const Storage& other) const {
176     return (
177         storage_impl_ == other.storage_impl_ ||
178         isSharedStorageAlias(*this, other));
179   }
180 
181   void UniqueStorageShareExternalPointer(
182       void* src,
183       size_t capacity,
184       DeleterFnPtr d = nullptr) {
185     if (!storage_impl_.unique()) {
186       TORCH_CHECK(
187           false,
188           "UniqueStorageShareExternalPointer can only be called when use_count == 1");
189     }
190     storage_impl_->UniqueStorageShareExternalPointer(src, capacity, d);
191   }
192 
UniqueStorageShareExternalPointerStorage193   void UniqueStorageShareExternalPointer(
194       at::DataPtr&& data_ptr,
195       size_t capacity) {
196     if (!storage_impl_.unique()) {
197       TORCH_CHECK(
198           false,
199           "UniqueStorageShareExternalPointer can only be called when use_count == 1");
200     }
201     storage_impl_->UniqueStorageShareExternalPointer(
202         std::move(data_ptr), capacity);
203   }
204 
205  protected:
206   c10::intrusive_ptr<StorageImpl> storage_impl_;
207 };
208 
209 template <>
210 struct MaybeOwnedTraits<c10::Storage> {
211   using owned_type = c10::Storage;
212   using borrow_type = c10::Storage;
213 
214   static borrow_type createBorrow(const owned_type& from) {
215     return borrow_type(borrow_type::unsafe_borrow_t{}, from);
216   }
217 
218   static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
219     lhs.unsafeReleaseStorageImpl();
220     lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
221   }
222 
223   static void destroyBorrow(borrow_type& toDestroy) {
224     toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
225   }
226 
227   static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
228     return borrow;
229   }
230 
231   static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
232     return &borrow;
233   }
234 
235   static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
236     return true;
237   }
238 };
239 
240 template <>
241 struct ExclusivelyOwnedTraits<c10::Storage> {
242   using repr_type = c10::Storage;
243   using pointer_type = c10::Storage*;
244   using const_pointer_type = const c10::Storage*;
245 
246   static repr_type nullRepr() {
247     return c10::Storage();
248   }
249 
250   template <class... Args>
251   static repr_type createInPlace(Args&&... args) {
252     return c10::Storage(std::forward<Args>(args)...);
253   }
254 
255   static repr_type moveToRepr(c10::Storage&& x) {
256     return std::move(x);
257   }
258 
259   static c10::Storage take(c10::Storage& x) {
260     return std::move(x);
261   }
262 
263   static pointer_type getImpl(repr_type& x) {
264     return &x;
265   }
266 
267   static const_pointer_type getImpl(const repr_type& x) {
268     return &x;
269   }
270 };
271 
272 } // namespace c10
273