1 #pragma once 2 3 #include <c10/core/Allocator.h> 4 #include <c10/core/Device.h> 5 #include <c10/core/DeviceType.h> 6 #include <c10/core/StorageImpl.h> 7 #include <c10/core/SymInt.h> 8 #include <c10/macros/Export.h> 9 #include <c10/util/Exception.h> 10 #include <c10/util/ExclusivelyOwned.h> 11 #include <c10/util/MaybeOwned.h> 12 #include <c10/util/UniqueVoidPtr.h> 13 #include <c10/util/intrusive_ptr.h> 14 #include <cstddef> 15 #include <utility> 16 17 namespace c10 { 18 19 struct Storage; 20 21 C10_API bool isSharedStorageAlias( 22 const Storage& storage0, 23 const Storage& storage1); 24 25 struct C10_API Storage { 26 public: 27 struct use_byte_size_t {}; 28 struct unsafe_borrow_t { 29 explicit unsafe_borrow_t() = default; 30 }; 31 32 Storage() = default; StorageStorage33 Storage(c10::intrusive_ptr<StorageImpl> ptr) 34 : storage_impl_(std::move(ptr)) {} 35 36 // Allocates memory buffer using given allocator and creates a storage with it 37 Storage( 38 use_byte_size_t /*use_byte_size*/, 39 const SymInt& size_bytes, 40 Allocator* allocator = nullptr, 41 bool resizable = false) storage_impl_Storage42 : storage_impl_(c10::make_intrusive<StorageImpl>( 43 StorageImpl::use_byte_size_t(), 44 size_bytes, 45 allocator, 46 resizable)) {} 47 48 // Creates storage with pre-allocated memory buffer. Allocator is given for 49 // potential future reallocations, however it can be nullptr if the storage 50 // is non-resizable 51 Storage( 52 use_byte_size_t /*use_byte_size*/, 53 size_t size_bytes, 54 at::DataPtr data_ptr, 55 at::Allocator* allocator = nullptr, 56 bool resizable = false) storage_impl_Storage57 : storage_impl_(c10::make_intrusive<StorageImpl>( 58 StorageImpl::use_byte_size_t(), 59 size_bytes, 60 std::move(data_ptr), 61 allocator, 62 resizable)) {} 63 64 protected: StorageStorage65 explicit Storage(unsafe_borrow_t, const Storage& rhs) 66 : storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim( 67 rhs.storage_impl_.get())) {} 68 69 friend MaybeOwnedTraits<Storage>; 70 71 public: 72 // Legacy constructor for partially initialized (dtype or memory) storages 73 // that can be temporarily created with Caffe2 APIs. See the note on top of 74 // TensorImpl.h for details. create_legacyStorage75 static Storage create_legacy(at::Device device) { 76 auto allocator = GetAllocator(device.type()); 77 return Storage(c10::make_intrusive<StorageImpl>( 78 StorageImpl::use_byte_size_t(), 79 0, 80 allocator->allocate(0), // materialize a non-default Device. 81 allocator, 82 true)); 83 } 84 85 // Mimic create_legacy, but without requiring a newly-created StorageImpl. reset_legacyStorage86 void reset_legacy() { 87 TORCH_CHECK(resizable() && allocator()); 88 set_nbytes(0); 89 set_data_ptr_noswap(allocator()->allocate(0)); 90 } 91 92 // TODO: remove later set_nbytesStorage93 void set_nbytes(size_t size_bytes) const { 94 storage_impl_->set_nbytes(size_bytes); 95 } 96 set_nbytesStorage97 void set_nbytes(c10::SymInt size_bytes) const { 98 storage_impl_->set_nbytes(std::move(size_bytes)); 99 } 100 resizableStorage101 bool resizable() const { 102 return storage_impl_->resizable(); 103 } 104 nbytesStorage105 size_t nbytes() const { 106 return storage_impl_->nbytes(); 107 } 108 sym_nbytesStorage109 SymInt sym_nbytes() const { 110 return storage_impl_->sym_nbytes(); 111 } 112 // get() use here is to get const-correctness 113 dataStorage114 const void* data() const { 115 return storage_impl_->data(); 116 } 117 mutable_dataStorage118 void* mutable_data() const { 119 return storage_impl_->mutable_data(); 120 } 121 mutable_data_ptrStorage122 at::DataPtr& mutable_data_ptr() const { 123 return storage_impl_->mutable_data_ptr(); 124 } 125 data_ptrStorage126 const at::DataPtr& data_ptr() const { 127 return storage_impl_->data_ptr(); 128 } 129 130 // Returns the previous data_ptr set_data_ptrStorage131 at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) const { 132 return storage_impl_->set_data_ptr(std::move(data_ptr)); 133 } 134 set_data_ptr_noswapStorage135 void set_data_ptr_noswap(at::DataPtr&& data_ptr) const { 136 return storage_impl_->set_data_ptr_noswap(std::move(data_ptr)); 137 } 138 device_typeStorage139 DeviceType device_type() const { 140 return storage_impl_->device_type(); 141 } 142 allocatorStorage143 at::Allocator* allocator() const { 144 return storage_impl_->allocator(); 145 } 146 deviceStorage147 at::Device device() const { 148 return storage_impl_->device(); 149 } 150 unsafeReleaseStorageImplStorage151 StorageImpl* unsafeReleaseStorageImpl() { 152 return storage_impl_.release(); 153 } 154 unsafeGetStorageImplStorage155 StorageImpl* unsafeGetStorageImpl() const noexcept { 156 return storage_impl_.get(); 157 } 158 getWeakStorageImplStorage159 c10::weak_intrusive_ptr<StorageImpl> getWeakStorageImpl() const { 160 return c10::weak_intrusive_ptr<StorageImpl>(storage_impl_); 161 } 162 163 operator bool() const { 164 return storage_impl_; 165 } 166 use_countStorage167 size_t use_count() const { 168 return storage_impl_.use_count(); 169 } 170 uniqueStorage171 inline bool unique() const { 172 return storage_impl_.unique(); 173 } 174 is_alias_ofStorage175 bool is_alias_of(const Storage& other) const { 176 return ( 177 storage_impl_ == other.storage_impl_ || 178 isSharedStorageAlias(*this, other)); 179 } 180 181 void UniqueStorageShareExternalPointer( 182 void* src, 183 size_t capacity, 184 DeleterFnPtr d = nullptr) { 185 if (!storage_impl_.unique()) { 186 TORCH_CHECK( 187 false, 188 "UniqueStorageShareExternalPointer can only be called when use_count == 1"); 189 } 190 storage_impl_->UniqueStorageShareExternalPointer(src, capacity, d); 191 } 192 UniqueStorageShareExternalPointerStorage193 void UniqueStorageShareExternalPointer( 194 at::DataPtr&& data_ptr, 195 size_t capacity) { 196 if (!storage_impl_.unique()) { 197 TORCH_CHECK( 198 false, 199 "UniqueStorageShareExternalPointer can only be called when use_count == 1"); 200 } 201 storage_impl_->UniqueStorageShareExternalPointer( 202 std::move(data_ptr), capacity); 203 } 204 205 protected: 206 c10::intrusive_ptr<StorageImpl> storage_impl_; 207 }; 208 209 template <> 210 struct MaybeOwnedTraits<c10::Storage> { 211 using owned_type = c10::Storage; 212 using borrow_type = c10::Storage; 213 214 static borrow_type createBorrow(const owned_type& from) { 215 return borrow_type(borrow_type::unsafe_borrow_t{}, from); 216 } 217 218 static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { 219 lhs.unsafeReleaseStorageImpl(); 220 lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); 221 } 222 223 static void destroyBorrow(borrow_type& toDestroy) { 224 toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0. 225 } 226 227 static const owned_type& referenceFromBorrow(const borrow_type& borrow) { 228 return borrow; 229 } 230 231 static const owned_type* pointerFromBorrow(const borrow_type& borrow) { 232 return &borrow; 233 } 234 235 static bool debugBorrowIsValid(const borrow_type& /*borrow*/) { 236 return true; 237 } 238 }; 239 240 template <> 241 struct ExclusivelyOwnedTraits<c10::Storage> { 242 using repr_type = c10::Storage; 243 using pointer_type = c10::Storage*; 244 using const_pointer_type = const c10::Storage*; 245 246 static repr_type nullRepr() { 247 return c10::Storage(); 248 } 249 250 template <class... Args> 251 static repr_type createInPlace(Args&&... args) { 252 return c10::Storage(std::forward<Args>(args)...); 253 } 254 255 static repr_type moveToRepr(c10::Storage&& x) { 256 return std::move(x); 257 } 258 259 static c10::Storage take(c10::Storage& x) { 260 return std::move(x); 261 } 262 263 static pointer_type getImpl(repr_type& x) { 264 return &x; 265 } 266 267 static const_pointer_type getImpl(const repr_type& x) { 268 return &x; 269 } 270 }; 271 272 } // namespace c10 273