#pragma once

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <c10/util/UniqueVoidPtr.h>

namespace c10 {

// A DataPtr is a unique pointer (with an attached deleter and some
// context for the deleter) to some memory, which also records which
// device its data lives on.
//
// nullptr DataPtrs can still have a nontrivial device; this allows
// us to treat zero-size allocations uniformly with non-zero allocations.
//
class C10_API DataPtr {
 private:
  c10::detail::UniqueVoidPtr ptr_;
  Device device_;

 public:
  // Choice of CPU here is arbitrary; if there's an "undefined" device
  // we could use that too
  DataPtr() : ptr_(), device_(DeviceType::CPU) {}
  DataPtr(void* data, Device device) : ptr_(data), device_(device) {}
  DataPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter, Device device)
      : ptr_(data, ctx, ctx_deleter), device_(device) {}
  void* operator->() const {
    return ptr_.get();
  }
  void clear() {
    ptr_.clear();
  }
  void* get() const {
    return ptr_.get();
  }
  void* mutable_get() {
    return ptr_.get();
  }
  void* get_context() const {
    return ptr_.get_context();
  }
  void* release_context() {
    return ptr_.release_context();
  }
  std::unique_ptr<void, DeleterFnPtr>&& move_context() {
    return ptr_.move_context();
  }
  operator bool() const {
    return static_cast<bool>(ptr_);
  }
  template <typename T>
  T* cast_context(DeleterFnPtr expected_deleter) const {
    return ptr_.cast_context<T>(expected_deleter);
  }
  DeleterFnPtr get_deleter() const {
    return ptr_.get_deleter();
  }
  /**
   * Compare the deleter in a DataPtr to expected_deleter.
   * If it matches, replace the deleter with new_deleter
   * and return true; otherwise, do nothing and return
   * false.
   *
   * In general, it is not safe to unconditionally set the
   * deleter on a DataPtr, because you don't know what
   * the deleter is, and thus will have a hard time properly
   * disposing of the deleter without storing the original
   * deleter (this is difficult to do, because DeleterFnPtr
   * is not a closure, and because the context on DataPtr is
   * only a single word, you generally don't have enough
   * space to store both the original deleter and its context).
   * However, in some cases, you know /exactly/ what the deleter
   * is, and you have a new deleter that manually wraps
   * the old one.  In this case, you can safely swap the deleter
   * after asserting that the deleters line up.
   *
   * What are the requirements on new_deleter?  It must still
   * properly dispose of the void* pointer passed in as its argument,
   * where void* is whatever the context of the original deleter
   * is.  So in general, you expect the new deleter to look something
   * like this:
   *
   *      [](void* ptr) {
   *        some_new_stuff(ptr);
   *        get_orig_allocator()->raw_deallocate(ptr);
   *      }
   *
   * Note that it won't work to close over the original
   * allocator; you don't have enough space to do that!  Also,
   * it's unsafe to assume that the pointer passed in is the
   * memory pointer in question; it might not be; be sure to
   * read the source code of the Allocator in question to
   * confirm this.
   */
  C10_NODISCARD bool compare_exchange_deleter(
      DeleterFnPtr expected_deleter,
      DeleterFnPtr new_deleter) {
    return ptr_.compare_exchange_deleter(expected_deleter, new_deleter);
  }
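  // Example use of compare_exchange_deleter (an illustrative sketch, not part
  // of this header): wrap the known original deleter of an allocation with
  // extra bookkeeping.  `orig_allocator` and `add_tracking` are hypothetical
  // names.
  //
  //   c10::DataPtr dp = orig_allocator->allocate(128);
  //   bool swapped = dp.compare_exchange_deleter(
  //       orig_allocator->raw_deleter(),
  //       +[](void* ctx) {
  //         add_tracking(ctx); // hypothetical bookkeeping hook
  //         c10::GetAllocator(c10::DeviceType::CPU)->raw_deallocate(ctx);
  //       });
  //   // swapped == false if dp's deleter was not the expected one.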
  Device device() const {
    return device_;
  }
  // Unsafely mutates the device on a DataPtr.  Under normal use,
  // you should never actually need to call this function.
  // We need this for the implementation of the hack detailed
  // in Note [Masquerading as CUDA]
  void unsafe_set_device(Device device) {
    device_ = device;
  }
};
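
// A minimal construction sketch (illustrative only, not part of this header):
// a DataPtr owning malloc'd CPU memory, where the data pointer doubles as the
// deleter context.  Assumes <cstdlib> for std::malloc/std::free.
//
//   void* buf = std::malloc(64);
//   c10::DataPtr dp(
//       buf, // data
//       buf, // context passed to the deleter
//       +[](void* ctx) { std::free(ctx); },
//       c10::Device(c10::DeviceType::CPU));
//   // Because dp.get() == dp.get_context(), an allocator producing such
//   // DataPtrs can also implement the raw interface (see the Note below).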

// NB: Device is NOT tested for here; a CUDA nullptr is as much a nullptr as a
// CPU nullptr

inline bool operator==(const DataPtr& dp, std::nullptr_t) noexcept {
  return !dp;
}
inline bool operator==(std::nullptr_t, const DataPtr& dp) noexcept {
  return !dp;
}
inline bool operator!=(const DataPtr& dp, std::nullptr_t) noexcept {
  return dp;
}
inline bool operator!=(std::nullptr_t, const DataPtr& dp) noexcept {
  return dp;
}

// Note [raw_allocate/raw_deallocate and Thrust]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Thrust's support for custom allocators requires us to write something
// like this:
//
//  class ThrustAllocator {
//    char* allocate(size_t);
//    void deallocate(char*, size_t);
//  };
//
// This is not good for our unique_ptr based allocator interface, as
// there is no way to get to the context when we free.
//
// However, in some cases the context is exactly the same as
// the data pointer.  In this case, we can support the "raw"
// allocate and deallocate interface.  This is what
// raw_deleter signifies.  By default, it returns a nullptr, which means that
// the raw interface is not implemented.  Be sure to implement it whenever
// possible, or the raw interface will be incorrectly reported as unsupported
// when it is actually available.

struct C10_API Allocator {
  virtual ~Allocator() = default;

  virtual DataPtr allocate(size_t n) = 0;

  // Clones an allocation that came from this allocator.
  //
  // To perform the copy, this function calls `copy_data`, which
  // must be implemented by derived classes.
  //
  // Note that this explicitly ignores any context that may have been
  // attached to the input data.
  //
  // Requires: input data was allocated by the same allocator.
  DataPtr clone(const void* data, std::size_t n);

  // Checks if DataPtr has a simple context, not wrapped with any out of the
  // ordinary contexts.
  virtual bool is_simple_data_ptr(const DataPtr& data_ptr) const;

  // If this returns a non-nullptr, it means that allocate()
  // is guaranteed to return a unique_ptr with this deleter attached;
  // it means the raw_allocate and raw_deallocate APIs are safe to use.
  // This function MUST always return the same BoundDeleter.
  virtual DeleterFnPtr raw_deleter() const {
    return nullptr;
  }
  void* raw_allocate(size_t n) {
    auto dptr = allocate(n);
    AT_ASSERT(dptr.get() == dptr.get_context());
    return dptr.release_context();
  }
  void raw_deallocate(void* ptr) {
    auto d = raw_deleter();
    AT_ASSERT(d);
    d(ptr);
  }

  // Copies data from one allocation to another.
  // Pure virtual, so derived classes must define behavior.
  // Derived class implementation can simply call `default_copy_data`
  // to use `std::memcpy`.
  //
  // Requires: src and dest were allocated by this allocator
  // Requires: src and dest both have length >= count
  virtual void copy_data(void* dest, const void* src, std::size_t count)
      const = 0;

 protected:
  // Uses `std::memcpy` to copy data.
  // Child classes can use this as `copy_data` when an alternative copy
  // API is not needed.
  void default_copy_data(void* dest, const void* src, std::size_t count) const;
};
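
// A minimal sketch of a concrete allocator (illustrative only; the name
// SimpleCPUAllocator is hypothetical).  The data pointer is reused as the
// deleter context, so the raw_allocate/raw_deallocate interface is valid.
//
//   struct SimpleCPUAllocator final : public c10::Allocator {
//     static void deleter(void* ctx) {
//       std::free(ctx);
//     }
//     c10::DataPtr allocate(size_t n) override {
//       void* data = std::malloc(n);
//       return {data, data, &deleter, c10::Device(c10::DeviceType::CPU)};
//     }
//     c10::DeleterFnPtr raw_deleter() const override {
//       return &deleter;
//     }
//     void copy_data(void* dest, const void* src, std::size_t count)
//         const override {
//       default_copy_data(dest, src, count);
//     }
//   };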

// This context is used to generate DataPtrs that have arbitrary
// std::function deleters associated with them.  In some user facing
// functions, we give a (user-friendly) interface for constructing
// tensors from external data which take an arbitrary std::function
// deleter.  Grep for InefficientStdFunctionContext to find these
// occurrences.
//
// This context is inefficient because we have to do a dynamic
// allocation of InefficientStdFunctionContext, on top of the dynamic
// allocation which is implied by std::function itself.
struct C10_API InefficientStdFunctionContext {
  void* ptr_;
  std::function<void(void*)> deleter_;
  InefficientStdFunctionContext(void* ptr, std::function<void(void*)> deleter)
      : ptr_(ptr), deleter_(std::move(deleter)) {}
  ~InefficientStdFunctionContext() {
    if (deleter_) {
      deleter_(ptr_);
    }
  }
  static DataPtr makeDataPtr(
      void* ptr,
      std::function<void(void*)> deleter,
      Device device);
};
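
// A usage sketch for makeDataPtr (illustrative only): adopting an externally
// owned buffer with a capturing std::function deleter.  `external_buffer`,
// `owner`, and `release_buffer` are hypothetical names.
//
//   c10::DataPtr dp = c10::InefficientStdFunctionContext::makeDataPtr(
//       external_buffer,
//       [owner](void* p) { release_buffer(owner, p); },
//       c10::Device(c10::DeviceType::CPU));
//   // The capturing lambda lives in a heap-allocated context that is
//   // destroyed, and hence invoked, when dp is cleared or destructed.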

/** Set the allocator for DeviceType `t`. The passed in allocator pointer is
 *  expected to have static lifetime; this function does NOT take ownership
 *  of the raw pointer. (The reason for this is to prevent existing pointers
 *  to an allocator of a particular device from being invalidated when
 *  SetAllocator is called.)
 *
 *  Also note that this is not thread-safe, and we assume this function will
 *  only be called during initialization.
 *
 *  The 'priority' flag is introduced when we want to overwrite the default
 *  allocator, since the allocators are set statically. The default priority
 *  is 0, which is the lowest. Only an equal or higher priority can overwrite
 *  an existing allocator.
 */
C10_API void SetAllocator(DeviceType t, Allocator* alloc, uint8_t priority = 0);
C10_API Allocator* GetAllocator(const DeviceType& t);

template <DeviceType t>
struct AllocatorRegisterer {
  explicit AllocatorRegisterer(Allocator* alloc) {
    SetAllocator(t, alloc);
  }
};

#define REGISTER_ALLOCATOR(t, f)                       \
  namespace {                                          \
  static c10::AllocatorRegisterer<t> g_allocator_d(f); \
  }
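
// A registration sketch (illustrative only): typically placed in the
// allocator's .cpp file so a translation-unit-local static registers itself
// at program startup.  SimpleCPUAllocator and g_simple_cpu_alloc are
// hypothetical names.
//
//   static SimpleCPUAllocator g_simple_cpu_alloc;
//   REGISTER_ALLOCATOR(c10::DeviceType::CPU, &g_simple_cpu_alloc);
//
// Afterwards, c10::GetAllocator(c10::DeviceType::CPU) returns this instance,
// unless a later SetAllocator call with an equal or higher priority replaces
// it.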

// An interface for reporting thread local memory usage
// per device
struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase {
  MemoryReportingInfoBase();
  ~MemoryReportingInfoBase() override = default;

  /**
   * alloc_size corresponds to the size of the ptr.
   *
   * total_allocated corresponds to total allocated memory.
   *
   * total_reserved corresponds to total size of memory pool, both used and
   * unused, if applicable.
   */
  virtual void reportMemoryUsage(
      void* ptr,
      int64_t alloc_size,
      size_t total_allocated,
      size_t total_reserved,
      Device device) = 0;

  virtual void reportOutOfMemory(
      int64_t alloc_size,
      size_t total_allocated,
      size_t total_reserved,
      Device device);

  virtual bool memoryProfilingEnabled() const = 0;
};
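
// A reporter sketch (illustrative only; the std::cerr body is a placeholder
// and assumes <iostream>).  Allocators report through
// reportMemoryUsageToProfiler below, which forwards to the active
// MemoryReportingInfoBase, if any.
//
//   struct LoggingMemoryReporter : public c10::MemoryReportingInfoBase {
//     void reportMemoryUsage(
//         void* ptr,
//         int64_t alloc_size,
//         size_t total_allocated,
//         size_t total_reserved,
//         c10::Device device) override {
//       // alloc_size may be negative when memory is freed.
//       std::cerr << "alloc_size=" << alloc_size << " device=" << device
//                 << "\n";
//     }
//     bool memoryProfilingEnabled() const override {
//       return true;
//     }
//   };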

C10_API bool memoryProfilingEnabled();
C10_API void reportMemoryUsageToProfiler(
    void* ptr,
    int64_t alloc_size,
    size_t total_allocated,
    size_t total_reserved,
    Device device);

C10_API void reportOutOfMemoryToProfiler(
    int64_t alloc_size,
    size_t total_allocated,
    size_t total_reserved,
    Device device);

// used to hold traceback information in allocators
struct GatheredContext {
  virtual ~GatheredContext() = default;
};

} // namespace c10