1 #pragma once 2 3 #include <c10/core/Allocator.h> 4 #include <c10/core/Device.h> 5 #include <c10/core/DeviceType.h> 6 #include <c10/core/SymInt.h> 7 #include <c10/core/impl/COW.h> 8 #include <c10/core/impl/COWDeleter.h> 9 #include <c10/core/impl/PyObjectSlot.h> 10 #include <c10/macros/Export.h> 11 #include <c10/util/Exception.h> 12 #include <c10/util/UniqueVoidPtr.h> 13 #include <c10/util/intrusive_ptr.h> 14 #include <cstddef> 15 #include <utility> 16 17 namespace c10 { 18 19 C10_API void throwNullDataPtrError(); 20 C10_API void warnDeprecatedDataPtr(); 21 22 // A storage represents the underlying backing data buffer for a 23 // tensor. This concept was inherited from the original Torch7 24 // codebase; we'd kind of like to get rid of the concept 25 // (see https://github.com/pytorch/pytorch/issues/14797) but 26 // it's hard work and no one has gotten around to doing it. 27 // 28 // NB: storage is supposed to uniquely own a data pointer; e.g., 29 // two non-null data pointers alias if and only if they are from 30 // the same storage. Technically you can violate this invariant 31 // (e.g., you can create a non-owning StorageImpl with at::from_blob) 32 // but a lot of things won't work correctly, including: 33 // 34 // - An ordinary deleter on such a storage is wrong, because normal deleters 35 // assume unique ownership, but if you have two storages at the same data, 36 // that implies there is some sort of shared ownership. So your deleter would 37 // have to actually be internally doing some sort of refcount thing 38 // - Deepcopy in Python side relies on storage equality and not data pointer 39 // equality; so if there are two separate storages pointing to the same data, 40 // the data will actually get duplicated in that case (one data ptr before, 41 // two data ptrs after) 42 // - Version counts won't work correctly, because we do all VC tracking at the 43 // level of storages (unless you explicitly disconnect the VC with detach); 44 // mutation because data pointers are the same are totally untracked 45 struct C10_API StorageImpl : public c10::intrusive_ptr_target { 46 public: 47 struct use_byte_size_t {}; 48 StorageImplStorageImpl49 StorageImpl( 50 use_byte_size_t /*use_byte_size*/, 51 SymInt size_bytes, 52 at::DataPtr data_ptr, 53 at::Allocator* allocator, 54 bool resizable) 55 : data_ptr_(std::move(data_ptr)), 56 size_bytes_(std::move(size_bytes)), 57 size_bytes_is_heap_allocated_(size_bytes_.is_heap_allocated()), 58 resizable_(resizable), 59 received_cuda_(false), 60 allocator_(allocator) { 61 if (resizable) { 62 TORCH_INTERNAL_ASSERT( 63 allocator_, "For resizable storage, allocator must be provided"); 64 } 65 refresh_has_data_ptr_check(); 66 } 67 StorageImplStorageImpl68 StorageImpl( 69 use_byte_size_t /*use_byte_size*/, 70 const SymInt& size_bytes, 71 at::Allocator* allocator, 72 bool resizable) 73 : StorageImpl( 74 use_byte_size_t(), 75 size_bytes, 76 size_bytes.is_heap_allocated() 77 ? allocator->allocate(0) 78 : allocator->allocate(size_bytes.as_int_unchecked()), 79 allocator, 80 resizable) {} 81 82 StorageImpl& operator=(StorageImpl&& other) = delete; 83 StorageImpl& operator=(const StorageImpl&) = delete; 84 StorageImpl() = delete; 85 StorageImpl(StorageImpl&& other) = delete; 86 StorageImpl(const StorageImpl&) = delete; 87 ~StorageImpl() override = default; 88 resetStorageImpl89 void reset() { 90 data_ptr_.clear(); 91 size_bytes_ = 0; 92 size_bytes_is_heap_allocated_ = false; 93 } 94 95 // Destructor doesn't call release_resources because it's 96 // unnecessary; don't forget to change that if needed! release_resourcesStorageImpl97 void release_resources() override { 98 data_ptr_.clear(); 99 } 100 nbytesStorageImpl101 size_t nbytes() const { 102 // OK to do this instead of maybe_as_int as nbytes is guaranteed positive 103 TORCH_CHECK(!size_bytes_is_heap_allocated_); 104 return size_bytes_.as_int_unchecked(); 105 } 106 sym_nbytesStorageImpl107 SymInt sym_nbytes() const { 108 return size_bytes_; 109 } 110 111 // TODO: remove later set_nbytesStorageImpl112 void set_nbytes(size_t size_bytes) { 113 size_bytes_ = static_cast<int64_t>(size_bytes); 114 size_bytes_is_heap_allocated_ = false; 115 } 116 set_nbytesStorageImpl117 void set_nbytes(c10::SymInt size_bytes) { 118 size_bytes_ = std::move(size_bytes); 119 } 120 resizableStorageImpl121 bool resizable() const { 122 return resizable_; 123 } 124 data_ptrStorageImpl125 const at::DataPtr& data_ptr() const { 126 return data_ptr_; 127 } 128 mutable_data_ptrStorageImpl129 at::DataPtr& mutable_data_ptr() { 130 if (C10_UNLIKELY(has_data_ptr_check_)) { 131 if (throw_on_mutable_data_ptr_) { 132 throwNullDataPtrError(); 133 } 134 if (warn_deprecated_on_mutable_data_ptr_) { 135 warnDeprecatedDataPtr(); 136 } 137 maybe_materialize_cow(); 138 } 139 return data_ptr_; 140 } 141 142 // Returns the data_ptr. Bypasses all checks. _mutable_data_ptr_no_checksStorageImpl143 at::DataPtr& _mutable_data_ptr_no_checks() { 144 return data_ptr_; 145 } 146 147 // Returns the previous data_ptr set_data_ptrStorageImpl148 at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) { 149 // We need to materialize the old COW DataPtr because it is 150 // being returned as mutable. 151 maybe_materialize_cow(); 152 return set_data_ptr_no_materialize_cow(std::move(data_ptr)); 153 } 154 set_data_ptr_noswapStorageImpl155 void set_data_ptr_noswap(at::DataPtr&& data_ptr) { 156 data_ptr_ = std::move(data_ptr); 157 refresh_has_data_ptr_check(); 158 } 159 dataStorageImpl160 const void* data() const { 161 return data_ptr_.get(); 162 } 163 mutable_dataStorageImpl164 void* mutable_data() { 165 if (C10_UNLIKELY(has_data_ptr_check_)) { 166 if (throw_on_mutable_data_ptr_) { 167 throwNullDataPtrError(); 168 } 169 if (warn_deprecated_on_mutable_data_ptr_) { 170 warnDeprecatedDataPtr(); 171 } 172 maybe_materialize_cow(); 173 } 174 return data_ptr_.mutable_get(); 175 } 176 device_typeStorageImpl177 at::DeviceType device_type() const { 178 return data_ptr_.device().type(); 179 } 180 allocatorStorageImpl181 at::Allocator* allocator() { 182 return allocator_; 183 } 184 allocatorStorageImpl185 const at::Allocator* allocator() const { 186 return allocator_; 187 } 188 189 // You generally shouldn't use this method, but it is occasionally 190 // useful if you want to override how a tensor will be reallocated, 191 // after it was already allocated (and its initial allocator was 192 // set) set_allocatorStorageImpl193 void set_allocator(at::Allocator* allocator) { 194 allocator_ = allocator; 195 } 196 deviceStorageImpl197 Device device() const { 198 return data_ptr_.device(); 199 } 200 set_resizableStorageImpl201 void set_resizable(bool resizable) { 202 if (resizable) { 203 // We need an allocator to be resizable 204 AT_ASSERT(allocator_); 205 } 206 resizable_ = resizable; 207 } 208 209 /** 210 * Can only be called when use_count is 1 211 */ 212 void UniqueStorageShareExternalPointer( 213 void* src, 214 size_t size_bytes, 215 DeleterFnPtr d = nullptr) { 216 UniqueStorageShareExternalPointer( 217 at::DataPtr(src, src, d, data_ptr_.device()), size_bytes); 218 } 219 220 /** 221 * Can only be called when use_count is 1 222 */ UniqueStorageShareExternalPointerStorageImpl223 void UniqueStorageShareExternalPointer( 224 at::DataPtr&& data_ptr, 225 size_t size_bytes) { 226 data_ptr_ = std::move(data_ptr); 227 size_bytes_ = static_cast<int64_t>(size_bytes); 228 size_bytes_is_heap_allocated_ = false; 229 allocator_ = nullptr; 230 resizable_ = false; 231 } 232 233 // This method can be used only after storage construction and cannot be used 234 // to modify storage status set_received_cudaStorageImpl235 void set_received_cuda(bool received_cuda) { 236 received_cuda_ = received_cuda; 237 } 238 received_cudaStorageImpl239 bool received_cuda() { 240 return received_cuda_; 241 } 242 pyobj_slotStorageImpl243 impl::PyObjectSlot* pyobj_slot() { 244 return &pyobj_slot_; 245 } 246 pyobj_slotStorageImpl247 const impl::PyObjectSlot* pyobj_slot() const { 248 return &pyobj_slot_; 249 } 250 set_throw_on_mutable_data_ptrStorageImpl251 void set_throw_on_mutable_data_ptr() { 252 throw_on_mutable_data_ptr_ = true; 253 refresh_has_data_ptr_check(); 254 } 255 set_warn_deprecated_on_mutable_data_ptrStorageImpl256 void set_warn_deprecated_on_mutable_data_ptr() { 257 warn_deprecated_on_mutable_data_ptr_ = true; 258 refresh_has_data_ptr_check(); 259 } 260 261 protected: 262 // materialize_cow_storage needs to call set_data_ptr_no_materlize_cow 263 friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage); 264 265 // Returns the previous data_ptr. If the old data_ptr was COW, 266 // this avoids materializing it set_data_ptr_no_materialize_cowStorageImpl267 at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) { 268 at::DataPtr old_data_ptr(std::move(data_ptr_)); 269 data_ptr_ = std::move(data_ptr); 270 refresh_has_data_ptr_check(); 271 return old_data_ptr; 272 } 273 274 private: refresh_has_data_ptr_checkStorageImpl275 void refresh_has_data_ptr_check() { 276 has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ || 277 warn_deprecated_on_mutable_data_ptr_; 278 } 279 is_cowStorageImpl280 inline bool is_cow() const { 281 return c10::impl::cow::is_cow_data_ptr(data_ptr_); 282 } 283 284 // Triggers a copy if this is a copy-on-write tensor. maybe_materialize_cowStorageImpl285 void maybe_materialize_cow() { 286 if (is_cow()) { 287 impl::cow::materialize_cow_storage(*this); 288 } 289 } 290 291 DataPtr data_ptr_; 292 SymInt size_bytes_; 293 bool size_bytes_is_heap_allocated_; 294 bool resizable_; 295 // Identifies that Storage was received from another process and doesn't have 296 // local to process cuda memory allocation 297 bool received_cuda_; 298 // All special checks in data/data_ptr calls are guarded behind this single 299 // boolean. This is for performance: .data/.data_ptr calls are commonly in the 300 // hot-path. 301 bool has_data_ptr_check_ = false; 302 // If we should throw when mutable_data_ptr() or mutable_data() is called. 303 bool throw_on_mutable_data_ptr_ = false; 304 // If we warn when mutable_data_ptr() or mutable_data() is called. 305 bool warn_deprecated_on_mutable_data_ptr_ = false; 306 Allocator* allocator_; 307 impl::PyObjectSlot pyobj_slot_; 308 }; 309 310 // Declare StorageImpl create function pointer types. 311 using StorageImplCreateHelper = intrusive_ptr<StorageImpl> (*)( 312 StorageImpl::use_byte_size_t, 313 SymInt size_bytes, 314 DataPtr data_ptr, 315 Allocator* allocator, 316 bool resizable); 317 318 C10_API void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr); 319 320 C10_API StorageImplCreateHelper GetStorageImplCreate(DeviceType t); 321 322 C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl( 323 c10::StorageImpl::use_byte_size_t use_byte_size, 324 c10::SymInt size_bytes, 325 c10::DataPtr data_ptr, 326 c10::Allocator* allocator, 327 bool resizable, 328 std::optional<at::Device> device_opt); 329 330 } // namespace c10 331