1 #pragma once 2 3 #include <cstddef> 4 #include <cstdint> 5 #include <functional> 6 #include <memory> 7 #include <utility> 8 9 #include <c10/core/Device.h> 10 #include <c10/core/DeviceType.h> 11 #include <c10/macros/Export.h> 12 #include <c10/macros/Macros.h> 13 #include <c10/util/Exception.h> 14 #include <c10/util/ThreadLocalDebugInfo.h> 15 #include <c10/util/UniqueVoidPtr.h> 16 17 namespace c10 { 18 19 // A DataPtr is a unique pointer (with an attached deleter and some 20 // context for the deleter) to some memory, which also records what 21 // device is for its data. 22 // 23 // nullptr DataPtrs can still have a nontrivial device; this allows 24 // us to treat zero-size allocations uniformly with non-zero allocations. 25 // 26 class C10_API DataPtr { 27 private: 28 c10::detail::UniqueVoidPtr ptr_; 29 Device device_; 30 31 public: 32 // Choice of CPU here is arbitrary; if there's an "undefined" device 33 // we could use that too DataPtr()34 DataPtr() : ptr_(), device_(DeviceType::CPU) {} DataPtr(void * data,Device device)35 DataPtr(void* data, Device device) : ptr_(data), device_(device) {} DataPtr(void * data,void * ctx,DeleterFnPtr ctx_deleter,Device device)36 DataPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter, Device device) 37 : ptr_(data, ctx, ctx_deleter), device_(device) {} 38 void* operator->() const { 39 return ptr_.get(); 40 } clear()41 void clear() { 42 ptr_.clear(); 43 } get()44 void* get() const { 45 return ptr_.get(); 46 } mutable_get()47 void* mutable_get() { 48 return ptr_.get(); 49 } get_context()50 void* get_context() const { 51 return ptr_.get_context(); 52 } release_context()53 void* release_context() { 54 return ptr_.release_context(); 55 } move_context()56 std::unique_ptr<void, DeleterFnPtr>&& move_context() { 57 return ptr_.move_context(); 58 } 59 operator bool() const { 60 return static_cast<bool>(ptr_); 61 } 62 template <typename T> cast_context(DeleterFnPtr expected_deleter)63 T* cast_context(DeleterFnPtr expected_deleter) const { 64 return ptr_.cast_context<T>(expected_deleter); 65 } get_deleter()66 DeleterFnPtr get_deleter() const { 67 return ptr_.get_deleter(); 68 } 69 /** 70 * Compare the deleter in a DataPtr to expected_deleter. 71 * If it matches, replace the deleter with new_deleter 72 * and return true; otherwise, does nothing and returns 73 * false. 74 * 75 * In general, it is not safe to unconditionally set the 76 * deleter on a DataPtr, because you don't know what 77 * the deleter is, and thus will have a hard time properly 78 * disposing of the deleter without storing the original 79 * deleter (this is difficult to do, because DeleterFnPtr 80 * is not a closure, and because the context on DataPtr is 81 * only a single word, you generally don't have enough 82 * space to store both the original deleter and its context). 83 * However, in some cases, you know /exactly/ what the deleter 84 * is, and you have a new deleter that manually wraps 85 * the old one. In this case, you can safely swap the deleter 86 * after asserting that the deleters line up. 87 * 88 * What are the requirements on new_deleter? It must still 89 * properly dispose of the void* pointer passed in as its argument, 90 * where void* is whatever the context of the original deleter 91 * is. So in general, you expect the new deleter to look something 92 * like this: 93 * 94 * [](void* ptr) { 95 * some_new_stuff(ptr); 96 * get_orig_allocator()->raw_deleter(ptr); 97 * } 98 * 99 * Note that it won't work to close over the original 100 * allocator; you don't have enough space to do that! Also, 101 * it's unsafe to assume that the passed in pointer in 102 * question is the memory pointer in question; it might not 103 * be; be sure to read the source code of the Allocator 104 * in question to confirm this. 105 */ compare_exchange_deleter(DeleterFnPtr expected_deleter,DeleterFnPtr new_deleter)106 C10_NODISCARD bool compare_exchange_deleter( 107 DeleterFnPtr expected_deleter, 108 DeleterFnPtr new_deleter) { 109 return ptr_.compare_exchange_deleter(expected_deleter, new_deleter); 110 } device()111 Device device() const { 112 return device_; 113 } 114 // Unsafely mutates the device on a DataPtr. Under normal use, 115 // you should never actually need to call this function. 116 // We need this for the implementation of the hack detailed 117 // in Note [Masquerading as CUDA] unsafe_set_device(Device device)118 void unsafe_set_device(Device device) { 119 device_ = device; 120 } 121 }; 122 123 // NB: Device is NOT tested for here; a CUDA nullptr is as much a nullptr as a 124 // CPU nullptr 125 126 inline bool operator==(const DataPtr& dp, std::nullptr_t) noexcept { 127 return !dp; 128 } 129 inline bool operator==(std::nullptr_t, const DataPtr& dp) noexcept { 130 return !dp; 131 } 132 inline bool operator!=(const DataPtr& dp, std::nullptr_t) noexcept { 133 return dp; 134 } 135 inline bool operator!=(std::nullptr_t, const DataPtr& dp) noexcept { 136 return dp; 137 } 138 139 // Note [raw_allocate/raw_deallocate and Thrust] 140 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 141 // Thrust's support for custom allocators requires us to write something 142 // like this: 143 // 144 // class ThrustAllocator { 145 // char* allocate(size_t); 146 // void deallocate(char*, size_t); 147 // }; 148 // 149 // This is not good for our unique_ptr based allocator interface, as 150 // there is no way to get to the context when we free. 151 // 152 // However, in some cases the context is exactly the same as 153 // the data pointer. In this case, we can support the "raw" 154 // allocate and deallocate interface. This is what 155 // raw_deleter signifies. By default, it returns a nullptr, which means that 156 // the raw interface is not implemented. Be sure to implement it whenever 157 // possible, or the raw interface will incorrectly reported as unsupported, 158 // when it is actually possible. 159 160 struct C10_API Allocator { 161 virtual ~Allocator() = default; 162 163 virtual DataPtr allocate(size_t n) = 0; 164 165 // Clones an allocation that came from this allocator. 166 // 167 // To perform the copy, this function calls `copy_data`, which 168 // must be implemented by derived classes. 169 // 170 // Note that this explicitly ignores any context that may have been 171 // attached to the input data. 172 // 173 // Requires: input data was allocated by the same allocator. 174 DataPtr clone(const void* data, std::size_t n); 175 176 // Checks if DataPtr has a simple context, not wrapped with any out of the 177 // ordinary contexts. 178 virtual bool is_simple_data_ptr(const DataPtr& data_ptr) const; 179 180 // If this returns a non nullptr, it means that allocate() 181 // is guaranteed to return a unique_ptr with this deleter attached; 182 // it means the rawAllocate and rawDeallocate APIs are safe to use. 183 // This function MUST always return the same BoundDeleter. raw_deleterAllocator184 virtual DeleterFnPtr raw_deleter() const { 185 return nullptr; 186 } raw_allocateAllocator187 void* raw_allocate(size_t n) { 188 auto dptr = allocate(n); 189 AT_ASSERT(dptr.get() == dptr.get_context()); 190 return dptr.release_context(); 191 } raw_deallocateAllocator192 void raw_deallocate(void* ptr) { 193 auto d = raw_deleter(); 194 AT_ASSERT(d); 195 d(ptr); 196 } 197 198 // Copies data from one allocation to another. 199 // Pure virtual, so derived classes must define behavior. 200 // Derived class implementation can simply call `default_copy_data` 201 // to use `std::memcpy`. 202 // 203 // Requires: src and dest were allocated by this allocator 204 // Requires: src and dest both have length >= count 205 virtual void copy_data(void* dest, const void* src, std::size_t count) 206 const = 0; 207 208 protected: 209 // Uses `std::memcpy` to copy data. 210 // Child classes can use this as `copy_data` when an alternative copy 211 // API is not needed. 212 void default_copy_data(void* dest, const void* src, std::size_t count) const; 213 }; 214 215 // This context is used to generate DataPtr which have arbitrary 216 // std::function deleters associated with them. In some user facing 217 // functions, we give a (user-friendly) interface for constructing 218 // tensors from external data which take an arbitrary std::function 219 // deleter. Grep for InefficientStdFunctionContext to find these 220 // occurrences. 221 // 222 // This context is inefficient because we have to do a dynamic 223 // allocation InefficientStdFunctionContext, on top of the dynamic 224 // allocation which is implied by std::function itself. 225 struct C10_API InefficientStdFunctionContext { 226 void* ptr_; 227 std::function<void(void*)> deleter_; InefficientStdFunctionContextInefficientStdFunctionContext228 InefficientStdFunctionContext(void* ptr, std::function<void(void*)> deleter) 229 : ptr_(ptr), deleter_(std::move(deleter)) {} ~InefficientStdFunctionContextInefficientStdFunctionContext230 ~InefficientStdFunctionContext() { 231 if (deleter_) { 232 deleter_(ptr_); 233 } 234 } 235 static DataPtr makeDataPtr( 236 void* ptr, 237 std::function<void(void*)> deleter, 238 Device device); 239 }; 240 241 /** Set the allocator for DeviceType `t`. The passed in allocator pointer is 242 * expected to have static lifetime; this function does NOT take ownership 243 * of the raw pointer. (The reason for this is to prevent existing pointers 244 * to an allocator of a particular device from being invalidated when 245 * SetAllocator is called.) 246 * 247 * Also note that this is not thread-safe, and we assume this function will 248 * only be called during initialization. 249 * 250 * The 'priority' flag is introduced when we want to overwrite the default 251 * allocator, since the allocators are set statically. The default priority 252 * is 0, which means the lowest. Only higher or equal priority can overwrite 253 * existing ones. 254 */ 255 C10_API void SetAllocator(DeviceType t, Allocator* alloc, uint8_t priority = 0); 256 C10_API Allocator* GetAllocator(const DeviceType& t); 257 258 template <DeviceType t> 259 struct AllocatorRegisterer { AllocatorRegistererAllocatorRegisterer260 explicit AllocatorRegisterer(Allocator* alloc) { 261 SetAllocator(t, alloc); 262 } 263 }; 264 265 #define REGISTER_ALLOCATOR(t, f) \ 266 namespace { \ 267 static c10::AllocatorRegisterer<t> g_allocator_d(f); \ 268 } 269 270 // An interface for reporting thread local memory usage 271 // per device 272 struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase { 273 MemoryReportingInfoBase(); 274 ~MemoryReportingInfoBase() override = default; 275 276 /** 277 * alloc_size corresponds to the size of the ptr. 278 * 279 * total_allocated corresponds to total allocated memory. 280 * 281 * total_reserved corresponds to total size of memory pool, both used and 282 * unused, if applicable. 283 */ 284 virtual void reportMemoryUsage( 285 void* ptr, 286 int64_t alloc_size, 287 size_t total_allocated, 288 size_t total_reserved, 289 Device device) = 0; 290 291 virtual void reportOutOfMemory( 292 int64_t alloc_size, 293 size_t total_allocated, 294 size_t total_reserved, 295 Device device); 296 297 virtual bool memoryProfilingEnabled() const = 0; 298 }; 299 300 C10_API bool memoryProfilingEnabled(); 301 C10_API void reportMemoryUsageToProfiler( 302 void* ptr, 303 int64_t alloc_size, 304 size_t total_allocated, 305 size_t total_reserved, 306 Device device); 307 308 C10_API void reportOutOfMemoryToProfiler( 309 int64_t alloc_size, 310 size_t total_allocated, 311 size_t total_reserved, 312 Device device); 313 314 // used to hold traceback information in allocators 315 struct GatheredContext { 316 virtual ~GatheredContext() = default; 317 }; 318 319 } // namespace c10 320