xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDAMallocAsyncAllocator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/cuda/CUDACachingAllocator.h>
2 #include <c10/cuda/CUDAException.h>
3 #include <c10/cuda/CUDAFunctions.h>
4 #include <c10/cuda/CUDAGuard.h>
5 #include <c10/util/UniqueVoidPtr.h>
6 #include <c10/util/flat_hash_map.h>
7 #include <c10/util/irange.h>
8 
9 #include <unordered_set>
10 #include <vector>
11 
12 namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync {
13 
14 using namespace c10::CachingDeviceAllocator;
15 
16 #if CUDA_VERSION >= 11040
17 // CUDA device allocator that uses cudaMallocAsync to implement
18 // the same interface as CUDACachingAllocator.cpp.
19 
20 // Designed to be safe for CUDA graph capture.
21 // Interactions with CUDA graph capture are mediated by
22 // notifyCaptureBegin
23 // notifyCaptureAboutToEnd
24 // notifyCaptureEnded
25 // notifyCaptureDestroy
26 
27 // Implementation details, not declared in CUDACachingAllocator.h
28 namespace {
29 
30 // General helpers
31 
32 struct UsageStream {
33   cudaStream_t stream;
34   c10::DeviceIndex device;
35   UsageStream() = default;
36   UsageStream(cudaStream_t s, c10::DeviceIndex d) : stream(s), device(d) {}
37   UsageStream(const UsageStream& us) = default;
38   UsageStream(UsageStream&& us) noexcept = default;
39   UsageStream& operator=(const UsageStream& other) = default;
40   UsageStream& operator=(UsageStream&& other) noexcept = default;
41 };
42 
43 bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
44   return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
45 }
46 
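// Hash functor so UsageStream can key the hash containers below; it combines
// the hash of the raw stream handle with the device index.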
47 struct UsageStreamHash {
48   size_t operator()(const UsageStream& us) const noexcept {
49     return std::hash<void*>{}(us.stream) + size_t(us.device);
50   }
51 };
52 
53 struct PtrUsage {
54   // recorded_streams holds side usage streams added by record_stream calls.
55   // In other words, it does NOT include the original creation stream.
56   ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
57   UsageStream creation_stream{};
58   uint64_t size;
59   bool captured;
60   PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
61 };
62 
63 int device_count = 0;
64 // these don't need to be c10::once_flags as in CUDAGeneratorImpl.cpp
65 // because they'll only be flipped by functions that have locked the mutex.
66 std::vector<bool> devs_initialized_flags;
67 std::vector<UsageStream> dummy_unifying_free_streams;
68 
69 // Possible micro-optimization:
70 // Some accesses to ptr_info are read-only.
71 // We could let those be concurrent with a shared_mutex and
72 // have concurrent calls take a shared_lock.
73 // Keeping it simple with an ordinary mutex for now.
74 std::mutex general_mutex;
75 
76 /**
77  * Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
78  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
79  * During CUDA graph capture, it's illegal to call cudaFreeAsync
80  * on a pointer that came from a non-captured cudaMallocAsync.
81  * Unfortunately, Python being what it is, it's impossible to be
82  * sure no uncaptured tensor will ever have its destructor called
83  * in a capturing region.
84  * We avoid errors by
85  *  1. remembering if allocated pointers were captured or uncaptured
86  *  2. during capture, if we detect an attempt to free an uncaptured
87  *     allocation on a capturing stream, don't free it immediately,
88  *     just remember it and defer its cudaFreeAsync call to after
89  *     the end of capture (specifically, to notifyCaptureEnded).
90  */
91 
92 using PtrInfo = ska::flat_hash_map<void*, PtrUsage>;
93 PtrInfo ptr_info;
94 std::vector<void*> ungraphed_ptrs_defer_free_until_no_capture;
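// Rough sketch of the deferral flow described in the note above (illustrative
// only; the real logic lives in freeAsync and mallocAsync below):
//
//   freeAsync(ptr), capture underway, ptr was NOT captured:
//     ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr);  // defer
//
//   mallocAsync(...), first allocation after capture has ended:
//     for (ptr : ungraphed_ptrs_defer_free_until_no_capture)
//       free_impl(ptr_info.find(ptr));                            // flush deferred frees
//     ungraphed_ptrs_defer_free_until_no_capture.clear();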
95 
96 // These two help setMemoryFraction limit the amount of memory
97 // used by PyTorch in particular (as opposed to other libraries
98 // in the same process that might be sharing the same cudaMemPool_t).
99 std::vector<size_t> pytorch_used_bytes;
100 std::vector<size_t> pytorch_memory_limits;
101 
102 // Graph-specific helpers
103 
104 /**
105  * Note [Avoid dangling free streams during CUDA graph capture]
106  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
107  * During capture, all stream dependencies must branch out from
108  * the stream on which capture began and rejoin this initial stream
109  * before capture ends.
110  * The user rigs desired forking and joining with event waits.
111  * But it's hard to be sure when tensor destructors get called relative
112  * to the final joins.
113  * For example, suppose a user
114  *   forks work stream B from initial capture stream A
115  *   creates a tensor T in B
116  *   joins by syncing A with B
117  *   ends capture.
118  * All well and good, right? Maybe not: maybe T went out of scope
119  * and its destructor got called AFTER the rejoin, leaving the graph with
120  * "unjoined work": a dangling cudaFreeAsync node in stream B.
121  * Ensuring that all tensor destructors for all side stream tensors
122  * are called before side streams rejoin the main stream is
123  * difficult. The user might have to add a bunch of explicit
124  * "del"s at the right spots in code that was fine for ordinary
125  * eager execution.
126  * Fortunately, we can spare the user this burden:
127  * during capture, we remember _all_ free streams,
128  * and manually rejoin them with the capture stream during
129  * notifyCaptureAboutToEnd.
130  * This approach is heavy-handed, but hopefully capture only needs to
131  * happen once, so we don't mind being heavy-handed.
132  *
133  * TODO: If, someday, we augment the graph bindings to support recapture
134  * https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#whole-graph-update
135  * (eg, as a way to accommodate dynamic params) we should think more
136  * carefully about the CPU overhead of remembering and rejoining
137  * all free streams during capture. Maybe it's not a big deal.
138  */
139 std::unordered_set<UsageStream, UsageStreamHash> capture_free_streams;
140 bool capture_underway = false;
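// Rough sketch of the manual rejoin described in the note above, as performed
// by endAllocateToPool() further down (illustrative only):
//
//   for (const auto& free_stream : capture_free_streams) {
//     cudaEventRecord(event, free_stream.stream);           // tail of the free stream
//     cudaStreamWaitEvent(capture_stream.stream(), event);  // fold it into the capture stream
//   }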
141 
142 // Implementation functions
143 
144 // Assumes the caller holds general_mutex
145 inline void lazy_init_device(c10::DeviceIndex device) {
146   if (!devs_initialized_flags[device]) {
147     CUDAGuard g(device);
148 
149     // See "Retaining memory in the pool" here:
150     // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
151     cudaMemPool_t mempool = nullptr;
152     C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
153     uint64_t threshold = UINT64_MAX;
154     C10_CUDA_CHECK(cudaMemPoolSetAttribute(
155         mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
156 
157     // I think all these are on by default, but I want to enable them
158     // explicitly to ensure awareness.
159     int enable = 1;
160     C10_CUDA_CHECK(cudaMemPoolSetAttribute(
161         mempool, cudaMemPoolReuseFollowEventDependencies, &enable));
162     C10_CUDA_CHECK(cudaMemPoolSetAttribute(
163         mempool, cudaMemPoolReuseAllowOpportunistic, &enable));
164     C10_CUDA_CHECK(cudaMemPoolSetAttribute(
165         mempool, cudaMemPoolReuseAllowInternalDependencies, &enable));
166 
167     // Grabs a stream from the current device to use as the "unifier" free
168     // stream for allocations that end up used on multiple streams.
169     const auto dufs = getStreamFromPool();
170     dummy_unifying_free_streams[device] =
171         UsageStream(dufs.stream(), dufs.device_index());
172 
173     pytorch_used_bytes[device] = 0;
174     pytorch_memory_limits[device] = UINT64_MAX;
175 
176     devs_initialized_flags[device] = true;
177   }
178 }
179 
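// Makes `dependent` wait (on the GPU timeline) for all work currently enqueued
// on `dependency`, using a short-lived event.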
180 inline void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
181   // CUDACachingAllocator.cpp uses raw cuda events, as do we.
182   cudaEvent_t event = nullptr;
183   C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
184   C10_CUDA_CHECK(cudaEventRecord(event, dependency));
185   C10_CUDA_CHECK(cudaStreamWaitEvent(dependent, event));
186   C10_CUDA_CHECK(cudaEventDestroy(event));
187 }
188 
189 // Assumes the caller holds general_mutex
190 inline void free_impl(PtrInfo::iterator& it) {
191   // Possible micro-optimization: If we did a value-copy here, we could move
192   // ptr_info.erase(it) up here and drop the lock immediately.
193   const auto& recorded_streams = it->second.recorded_streams;
194   const auto& creation_stream = it->second.creation_stream;
195 
196   // If the usage stream is a null (default) stream,
197   // cudaFreeAsync infers the device from the ambient context,
198   // so we need to set the right ambient context.
199   CUDAGuard g(creation_stream.device);
200 
201   if (recorded_streams.empty()) {
202     // ptr was only used on one stream, which must have been
203     // the original allocation stream.
204     // Frees ptr in the original allocation stream.
205 
206     C10_CUDA_CHECK(cudaFreeAsync(it->first, creation_stream.stream));
207 
208     if (C10_UNLIKELY(capture_underway)) {
209       // See Note [Avoid dangling free streams during CUDA graph capture]
210       capture_free_streams.insert(creation_stream);
211     }
212   } else {
213     // ptr was used on many streams. We don't know which was the most recent.
214     // There could even have been multiple most recent usage streams acting
215     // on different regions of the memory.
216     // But cudaFreeAsync only accepts a single most recent usage stream.
217     // We can still safely free ptr with a trick:
218     // Use a dummy "unifying stream", sync the unifying stream with all of
219     // ptr's usage streams, and pass the dummy stream to cudaFreeAsync.
220 
221     // Retrieves the dummy "unifier" stream from the device
222     // on which the pointer was originally allocated.
223     auto dummy_unifying_free_stream =
224         dummy_unifying_free_streams[creation_stream.device];
225     TORCH_INTERNAL_ASSERT(
226         dummy_unifying_free_stream.device == creation_stream.device);
227 
228     // we're already on creation_stream.device, no need to re-guard
229     sync_raw(creation_stream.stream, dummy_unifying_free_stream.stream);
230 
231     // The number of usage streams is typically small (low single digits)
232     for (const auto& recorded_stream : recorded_streams) {
233       // Logic here accommodates the chance some of the usage streams were on
234       // other devices, which is possible if some usage kernels accessed the
235       // memory via p2p.
236 
237       // cudaEventRecord requires that the input event and stream are on the
238       // same device.
239       CUDAGuard g_usage(recorded_stream.device);
240 
241       sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream);
242     }
243 
244     // Frees ptr in the dummy "unifier" stream.
245     C10_CUDA_CHECK(cudaFreeAsync(it->first, dummy_unifying_free_stream.stream));
246     // At this point, unless dummy_unifying_free_stream happens to alias some
247     // future user stream, the allocation is only available for "opportunistic"
248     // reuse, ie, if the CPU sees dummy_unifying_free_stream has reached the
249     // point that all events recorded on all usage streams have resolved from
250     // the CPU's perspective. In theory, we could remove the need for the driver
251     // to do this tracking by e.g. replacing
252     // cudaStreamWaitEvent(dummy_unifying_free_stream.stream, event);
253     // with
254     // cudaStreamWaitEvent(creation_stream.stream, event);
255     // then cudaFreeAsyncing straight back into creation_stream.stream,
256     // but this forces a potentially false dependency of creation_stream.stream
257     // on all the recorded_streams.
258 
259     if (C10_UNLIKELY(capture_underway)) {
260       // See Note [Avoid dangling free streams during CUDA graph capture]
261       capture_free_streams.emplace(
262           dummy_unifying_free_stream.stream, dummy_unifying_free_stream.device);
263     }
264   }
265 
266   pytorch_used_bytes[creation_stream.device] -= it->second.size;
267 
268   ptr_info.erase(it);
269 }
270 
271 void freeAsync(void* ptr) {
272   std::lock_guard<std::mutex> lk(general_mutex);
273 
274   auto err = cudaGetLastError();
275   C10_CUDA_CHECK(err);
276   auto it = ptr_info.find(ptr);
277   TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
278 
279   if (C10_UNLIKELY(capture_underway)) {
280     if (!it->second.captured) {
281       TORCH_WARN_ONCE(
282           "freeAsync() was called on an uncaptured allocation during graph capture "
283           "(address = ",
284           ptr,
285           "). This may be benign, for example, a Python tensor in the capture "
286           "might happen to shadow (use the same name as) an unrelated temporary "
287           "tensor from somewhere before capture, pushing the earlier tensor "
288           "out of scope. "
289           "However, if the tensor we're freeing here IS used by the capture, "
290           "freeing it is an error, and may cause illegal memory accesses or "
291           "memory corruption during graph replay.");
292       // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
293       // Remembers the raw pointer, not the iterator.
294       // This forces notifyCaptureEnded to do another lookup,
295       // but avoids the risk the iterator might be invalidated
296       // between now and then.
297       ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr);
298       return;
299     }
300   } else if (C10_UNLIKELY(it->second.captured)) {
301     TORCH_WARN(
302         "Attempting uncaptured free of a captured allocation with address ",
303         ptr,
304         "\nThis is technically allowed, but may indicate you are losing "
305         "the last user-visible tensor through which the allocation can "
306         "be accessed, so you'll have no way to view the data after "
307         "future replays of the owning graph.");
308   }
309 
310   free_impl(it);
311 }
312 
313 // Symmetric with NativeCachingAllocator::malloc for now,
314 // although I don't think we absolutely need the symmetry.
315 void mallocAsync(
316     void** devPtr,
317     c10::DeviceIndex device,
318     size_t size,
319     cudaStream_t stream) {
320   TORCH_INTERNAL_ASSERT(
321       0 <= device && device < device_count,
322       "Invalid device index ",
323       device,
324       ": did you call init?");
325 
326   // If stream is a null (default) stream,
327   // cudaMallocAsync infers the device from the ambient context,
328   // so we need to set the right ambient context.
329   CUDAGuard g(device);
330 
331   std::lock_guard<std::mutex> lk(general_mutex);
332 
333   if (!capture_underway &&
334       !ungraphed_ptrs_defer_free_until_no_capture.empty()) {
335     // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
336     for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture) {
337       auto it = ptr_info.find(ptr);
338       TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
339       free_impl(it);
340     }
341 
342     ungraphed_ptrs_defer_free_until_no_capture.clear();
343   }
344 
345   lazy_init_device(device);
346 
347   // Defensively checks for preexisting CUDA error state.
348   auto err = cudaGetLastError();
349   C10_CUDA_CHECK(err);
350 
351   // TODO: Could we avoid calling cudaMallocAsync while holding general_mutex,
352   // perhaps by letting lazy_init_device use separate once_flags or an internal
353   // static initializer?
354   if (pytorch_used_bytes[device] + size > pytorch_memory_limits[device]) {
355     err = cudaErrorMemoryAllocation;
356   } else {
357     err = cudaMallocAsync(devPtr, size, stream);
358   }
359 
360   if (err == cudaErrorMemoryAllocation) {
361     // Clears CUDA's internal error state so the user, if desired, can catch the
362     // OOM exception, free some stuff on the script side, and retry the
363     // allocation. This aligns with the behavior of alloc_block in
364     // CUDACachingAllocator.cpp.
365     (void)cudaGetLastError(); // clear CUDA error
366     size_t device_free = 0;
367     size_t device_total = 0;
368     C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
369     TORCH_CHECK_WITH(
370         OutOfMemoryError,
371         false,
372         "Allocation on device ",
373         device,
374         " would exceed allowed memory. (out of memory)",
375         "\nCurrently allocated     : ",
376         format_size(pytorch_used_bytes[device]),
377         "\nRequested               : ",
378         format_size(size),
379         "\nDevice limit            : ",
380         format_size(device_total),
381         "\nFree (according to CUDA): ",
382         format_size(device_free),
383         "\nPyTorch limit (set by user-supplied memory fraction)"
384         "\n                        : ",
385         format_size(pytorch_memory_limits[device]));
386   } else {
387     C10_CUDA_CHECK(err);
388   }
389 
390   auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
391   TORCH_INTERNAL_ASSERT(
392       inserted.second,
393       "address returned by cudaMallocAsync already exists "
394       "in ptr_info");
395 
396   inserted.first->second.creation_stream = {stream, device};
397 
398   pytorch_used_bytes[device] += size;
399 }
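// For illustration only: the "catch the OOM exception ... and retry" behavior
// that mallocAsync's OOM path above is designed to allow looks roughly like
// this on the Python side (assuming the usual bindings surface the failure as
// torch.cuda.OutOfMemoryError; names below are made up):
//
//   try:
//       buf = torch.empty(n, device="cuda")
//   except torch.cuda.OutOfMemoryError:
//       del some_cached_tensors              # free something script-side
//       buf = torch.empty(n, device="cuda")  # retry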
400 
401 } // anonymous namespace
402 
403 void local_raw_delete(void* ptr);
404 
405 // Same pattern as CUDACachingAllocator.cpp.
406 struct CudaMallocAsyncAllocator : public CUDAAllocator {
407   DataPtr allocate(size_t size) override {
408     constexpr size_t one_exa_bytes = 1152921504606846976ULL;
409     TORCH_CHECK_WITH(
410         OutOfMemoryError,
411         size < one_exa_bytes,
412         "CUDA out of memory. Tried to allocate more than 1EB memory.");
413     c10::DeviceIndex device = 0;
414     C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
415     void* r = nullptr;
416     if (size != 0) {
417       mallocAsync(&r, device, size, cuda::getCurrentCUDAStream(device));
418     }
419     return {r, r, &local_raw_delete, Device(DeviceType::CUDA, device)};
420   }
421   DeleterFnPtr raw_deleter() const override {
422     return &local_raw_delete;
423   }
424 
425   // This function should not issue any context-creating calls,
426   // just set up for later calls to init per-device pools based
427   // on the current device each later call sees.
428   void init(int dev_count) override {
429     static bool called = [](int dev_count) {
430       ;
431       // Are there external guarantees init will be called before
432       // any of the allocator's other functions?
433       // std::lock_guard<std::mutex> lk(general_mutex);
434       device_count = dev_count;
435       devs_initialized_flags.resize(dev_count, false);
436       dummy_unifying_free_streams.resize(dev_count);
437       pytorch_used_bytes.resize(dev_count);
438       pytorch_memory_limits.resize(dev_count);
439       return true;
440     }(dev_count);
441     (void)called;
442   }
443 
444   bool initialized() override {
445     return !devs_initialized_flags.empty();
446   }
447 
448   static inline void assertValidDevice(c10::DeviceIndex device) {
449     TORCH_CHECK(
450         0 <= device && device < device_count, "Invalid device argument.");
451   }
452 
453   void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
454     TORCH_INTERNAL_ASSERT(
455         0 <= fraction && fraction <= 1,
456         "invalid fraction:",
457         fraction,
458         ". Please set within (0, 1).");
459 
460     std::lock_guard<std::mutex> lk(general_mutex);
461     assertValidDevice(device);
462     CUDAGuard g(device);
463     // Should setMemoryFraction be allowed to trigger a full device context and
464     // pool-creating lazy_init_device, or should we simply assert this device is
465     // already initialized, ie
466     // TORCH_CHECK(devs_initialized_flags[device], ...)?
467     lazy_init_device(device);
468 
469     size_t device_free = 0;
470     size_t device_total = 0;
471     C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
472     pytorch_memory_limits[device] =
473         static_cast<uint64_t>(fraction * static_cast<double>(device_total));
474 
475     // Alternative: Instead of a manual hard limit, we could use
476     // cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold,
477     // &threshold); This is a soft hint: The driver allows the pool's reserved
478     // memory to spike above threshold in regions of high cudaMallocAsync
479     // demand, but opportunistically trims reserved memory back to threshold
480     // when the memory in use is < threshold. I don't like this because it
481     // introduces performance nondeterminism.
482   }
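  // For reference, a sketch of how this limit is typically set from user code
  // (assuming the standard Python binding routes through setMemoryFraction):
  //
  //   torch.cuda.set_per_process_memory_fraction(0.5, device=0)
  //
  // After that, an allocation that would push pytorch_used_bytes[0] past half
  // of device 0's total memory fails with an OOM error in mallocAsync above,
  // even if the device itself still has free memory.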
483 
484   void emptyCache() override {
485     std::lock_guard<std::mutex> lk(general_mutex);
486 
487     for (int dev = 0; dev < device_count; dev++) {
488       if (devs_initialized_flags[dev]) {
489         CUDAGuard g(static_cast<c10::DeviceIndex>(dev));
490 
491         cudaMemPool_t mempool = nullptr;
492         cudaDeviceGetDefaultMemPool(&mempool, dev);
493         cudaDeviceSynchronize();
494         cudaMemPoolTrimTo(mempool, 0);
495       }
496     }
497   }
498 
499   void cacheInfo(c10::DeviceIndex device, size_t* maxWorkspaceGuess) override {
500     // The only consumer of cacheInfo is getMaxWorkspaceSize in Conv_v7.cpp.
501     // Afaict, the role of cacheInfo is to give getMaxWorkspaceSize a reasonable
502     // maximum workspace size to use for an upcoming cudnnFind call.
503     //
504     // The native allocator's cacheInfo chooses to return the size of its
505     // largest unused block (which is the largest allocation the native
506     // allocator can service immediately and asynchronously without a
507     // cudaMalloc).
508     //
509     // Here, we use a different heuristic: figure out the max usable workspace
510     // size with a bit of educated trial and error. It's ok to be
511     // perf-inefficient because cacheInfo is a prelude to cudnnFind.
512     //
513     // The algo cache then stores the best-performing algo with workspace <=
514     // maxWorkspaceGuess. Later calls with the same param set hit in cache and
515     // try to allocate the same workspace. If, in one of those future calls,
516     // workspace allocation fails (ie because less ambient memory is available),
517     // the bindings rerun cudnnFind, including calling cacheInfo again
518     // beforehand to estimate a new (smaller) largest-available workspace. Over
519     // a few such calls, the cache should settle to the algo with a workspace
520     // size that's small enough to succeed every time (for that param set).
521     //
522     // So the strategy here is to return a rough, largeish guess and let the
523     // bindings retry to trim as needed over time.
524     //
525     // The only caveat is, even if a workspace is allocated without OOM errors
526     // now and in future calls, it's hard to be sure those later error-free
527     // cudaMallocAsyncs are fast and come straight from the pool (ie,
528     // cudaMallocAsync didn't need to reserve more memory from the system).
529     // Hopefully, after repeated workspace requests, the pool's reserved memory
530     // also stabilizes to a point where they all come straight from the pool.
531     std::lock_guard<std::mutex> lk(general_mutex);
532     assertValidDevice(device);
533     CUDAGuard g(device);
534     lazy_init_device(device);
535 
536     size_t free_upper_bound = 0;
537     size_t device_total = 0;
538     C10_CUDA_CHECK(cudaMemGetInfo(&free_upper_bound, &device_total));
539     TORCH_INTERNAL_ASSERT(
540         free_upper_bound + pytorch_used_bytes[device] <= device_total);
541     size_t guess = std::min(
542         free_upper_bound,
543         pytorch_memory_limits[device] - pytorch_used_bytes[device]);
544     auto stream = c10::cuda::getCurrentCUDAStream();
545     void* dummy = nullptr;
546 
547     // Defensively checks for preexisting CUDA error state.
548     auto err = cudaGetLastError();
549     C10_CUDA_CHECK(err);
550 
551     while (true) {
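    // Illustrative example of the loop below (sizes are made up): if the
    // initial guess is 8 GiB and the trial cudaMallocAsync fails, retry with
    // 4 GiB, then 2 GiB, and so on, until a trial allocation succeeds; that
    // size becomes *maxWorkspaceGuess.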
552       // Duplicates some logic from mallocAsync to work with the error state
553       // directly instead of repeatedly catching an exception thrown by
554       // mallocAsync.
555       if (pytorch_used_bytes[device] + guess > pytorch_memory_limits[device]) {
556         err = cudaErrorMemoryAllocation;
557       } else {
558         err = cudaMallocAsync(&dummy, guess, stream);
559       }
560 
561       if (err == cudaSuccess) {
562         cudaFreeAsync(dummy, stream);
563         *maxWorkspaceGuess = guess;
564         return;
565       } else if (err == cudaErrorMemoryAllocation) {
566         (void)cudaGetLastError(); // clear CUDA error
567         guess >>= 1; // quick and dirty: try half the size next iteration
568       } else {
569         C10_CUDA_CHECK(err);
570       }
571     }
572   }
573 
574   void* getBaseAllocation(void* ptr, size_t* size) override {
575     std::lock_guard<std::mutex> lk(general_mutex);
576 
577     auto it = ptr_info.find(ptr);
578     TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
579 
580     if (size) {
581       *size = it->second.size;
582     }
583 
584     return ptr;
585   }
586 
587   void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) override {
588     std::lock_guard<std::mutex> lk(general_mutex);
589     auto ptr_val = ptr.get();
590     // An empty tensor's storage().data() might be a null ptr. As there are
591     // no blocks associated with those tensors, it is fine to do nothing here.
592     if (!ptr_val) {
593       return;
594     }
595 
596     // The pointer should exist in the map already.
597     auto it = ptr_info.find(ptr_val);
598     TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
599 
600     UsageStream to_record{stream.stream(), stream.device_index()};
601     if (to_record == it->second.creation_stream) {
602       TORCH_WARN_ONCE(
603           "Called record_stream on tensor whose original creation stream "
604           "matches the recorded stream. This is unnecessary and has no effect.");
605     } else {
606       it->second.recorded_streams.insert(to_record);
607     }
608   }
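  // For reference, the user-side call that typically lands here (assuming the
  // usual Python binding) is tensor.record_stream(side_stream): it tells the
  // allocator that the tensor's memory is also in use on side_stream, so
  // free_impl later syncs that stream before handing the block to cudaFreeAsync.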
609 
610   ShareableHandle shareIpcHandle(void* handle) override {
611     TORCH_CHECK(
612         false,
613         "cudaMallocAsync does not yet support shareIpcHandle. "
614         "If you need it, please file an issue describing your use case.");
615   }
616 
617   std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
618     TORCH_CHECK(
619         false,
620         "cudaMallocAsync does not yet support getIpcDevPtr. "
621         "If you need it, please file an issue describing your use case.");
622   }
623 
624   void recordHistory(
625       bool enabled,
626       CreateContextFn context_recorder,
627       size_t alloc_trace_max_entries,
628       RecordContext when) override {
629     TORCH_CHECK(
630         false,
631         "cudaMallocAsync does not yet support recordHistory. "
632         "If you need it, please file an issue describing your use case.");
633   }
634 
635   void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override {
636     TORCH_CHECK(
637         false,
638         "cudaMallocAsync does not yet support attachOutOfMemoryObserver. "
639         "If you need it, please file an issue describing your use case.");
640   }
641 
642   void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) override {
643     TORCH_CHECK(
644         false,
645         "cudaMallocAsync does not yet support attachAllocatorTraceTracker. "
646         "If you need it, please file an issue describing your use case.");
647   }
648 
649   std::shared_ptr<AllocatorState> getCheckpointState(
650       c10::DeviceIndex device,
651       MempoolId_t id) override {
652     TORCH_CHECK(
653         false,
654         "cudaMallocAsync does not yet support getCheckpointState. "
655         "If you need it, please file an issue describing your use case.");
656   }
657 
658   CheckpointDelta setCheckpointPoolState(
659       c10::DeviceIndex device,
660       std::shared_ptr<AllocatorState> pps) override {
661     TORCH_CHECK(
662         false,
663         "cudaMallocAsync does not yet support setCheckpointPoolState. "
664         "If you need it, please file an issue describing your use case.");
665   }
666 
667   // Collects stats for device.
668   // If device hasn't been used yet, returns 0s without creating a context.
669   DeviceStats getDeviceStats(c10::DeviceIndex device) override {
670     assertValidDevice(device);
671 
672     // Memory currently reserved by the mempool
673     uint64_t reserved_mem_current = 0;
674     // High-water mark of memory reserved by the mempool since last reset
675     uint64_t reserved_mem_peak = 0;
676     // Memory currently in use by the mempool
677     uint64_t used_mem_current = 0;
678     // High-water mark of memory in use by the mempool since last reset
679     uint64_t used_mem_peak = 0;
680 
681     std::lock_guard<std::mutex> lk(general_mutex);
682 
683     if (devs_initialized_flags[device]) {
684       CUDAGuard g(device);
685 
686       cudaMemPool_t mempool = nullptr;
687       C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
688       C10_CUDA_CHECK(cudaMemPoolGetAttribute(
689           mempool, cudaMemPoolAttrReservedMemCurrent, &reserved_mem_current));
690 
691       C10_CUDA_CHECK(cudaMemPoolGetAttribute(
692           mempool, cudaMemPoolAttrReservedMemHigh, &reserved_mem_peak));
693 
694       C10_CUDA_CHECK(cudaMemPoolGetAttribute(
695           mempool, cudaMemPoolAttrUsedMemCurrent, &used_mem_current));
696 
697       C10_CUDA_CHECK(cudaMemPoolGetAttribute(
698           mempool, cudaMemPoolAttrUsedMemHigh, &used_mem_peak));
699     }
700 
701     // Many stat types are specific to the native allocator. We leave these
702     // untouched. Their "struct Stat"s will contain zeroed values.
703     DeviceStats stats;
704 
705     // In the native allocator:
706     // allocated_bytes is the total bytes of blocks that have been malloc()ed
707     // and not yet free()d.
708     // active_bytes is the total bytes of blocks that have been malloc()ed but
709     // not yet released back into a free pool. In other words, it includes all
710     // allocated_bytes, as well as the bytes of "limbo state" blocks that have
711     // already been free()ed but not yet free_block()ed back into a pool due to
712     // outstanding stream_uses.
713     //
714     // Here, in the cudaMallocAsync allocator:
715     // We simply ask the driver's opinion about active memory.
716     // We don't bother distinguishing between allocated_bytes and active_bytes.
717     stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
718         static_cast<int64_t>(used_mem_current);
719     stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
720         static_cast<int64_t>(used_mem_peak);
721     stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
722         static_cast<int64_t>(used_mem_current);
723     stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
724         static_cast<int64_t>(used_mem_peak);
725     stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
726         static_cast<int64_t>(reserved_mem_current);
727     stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
728         static_cast<int64_t>(reserved_mem_peak);
729 
730     return stats;
731   }
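  // For reference (assuming the usual Python bindings): these AGGREGATE values
  // are what torch.cuda.memory_allocated(), memory_reserved(), and
  // memory_stats() report when this backend is active; the remaining
  // native-allocator-specific stats stay zeroed as noted above.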
732 
733   void resetAccumulatedStats(c10::DeviceIndex device) override {
734     assertValidDevice(device);
735     TORCH_WARN_ONCE(
736         "For backend:cudaMallocAsync, resetAccumulatedStats has no effect.");
737   }
738 
739   void resetPeakStats(c10::DeviceIndex device) override {
740     assertValidDevice(device);
741 
742     CUDAGuard g(device);
743     cudaMemPool_t mempool = nullptr;
744     C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
745     // Using zero as the reset value is the method recommended by the CUDA driver
746     // team. Vivek Kini says:
747     //   "Resetting to zero (which is the only valid value when setting
748     //    ReservedMemHigh) resets it to ReservedMemCurrent inside the driver
749     //   (same goes for UsedMemHigh/UsedMemCurrent)"
750     uint64_t zero = 0;
751     C10_CUDA_CHECK(cudaMemPoolSetAttribute(
752         mempool, cudaMemPoolAttrReservedMemHigh, &zero));
753     C10_CUDA_CHECK(
754         cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero));
755   }
756 
757   SnapshotInfo snapshot() override {
758     TORCH_CHECK(
759         false,
760         "Calling snapshot with backend:cudaMallocAsync is not meaningful. "
761         "(For backend:native, snapshot returns a detailed summary of all "
762         "blocks tracked by the allocator, but the cudaMallocAsync backend "
763         "does not track individual blocks.)");
764     // Alternative: TORCH_WARN
765     return {};
766   }
767 
768   // CUDAGraph interactions
769   void beginAllocateToPool(
770       c10::DeviceIndex device,
771       MempoolId_t mempool_id,
772       std::function<bool(cudaStream_t)>) override {
773     std::lock_guard<std::mutex> lk(general_mutex);
774 
775     TORCH_INTERNAL_ASSERT(capture_free_streams.empty());
776     TORCH_CHECK(
777         !capture_underway,
778         "Only one capture at a time is allowed in a process.")
779     capture_underway = true;
780   }
781 
782   void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id)
783       override {
784     assertValidDevice(device);
785 
786     std::lock_guard<std::mutex> lk(general_mutex);
787 
788     TORCH_CHECK(
789         capture_underway,
790         "CudaMallocAsync::notifyCaptureAboutToEnd called, "
791         "but CudaMallocAsync::capture_underway is false.");
792 
793     auto capture_stream = cuda::getCurrentCUDAStream(device);
794 
795     // See Note [Avoid dangling free streams during CUDA graph capture]
796     for (const auto& free_stream : capture_free_streams) {
797       // cudaEventRecord requires that the input event and stream are on the
798       // same device.
799       CUDAGuard g(free_stream.device);
800 
801       // CUDACachingAllocator.cpp uses raw cuda events, as do we.
802       cudaEvent_t event = nullptr;
803       C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
804       C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream));
805       C10_CUDA_CHECK(cudaStreamWaitEvent(capture_stream.stream(), event));
806       C10_CUDA_CHECK(cudaEventDestroy(event));
807     }
808 
809     capture_free_streams.clear();
810     TORCH_CHECK(
811         capture_underway,
812         "CudaMallocAsync::notifyCaptureEnded called, "
813         "but CudaMallocAsync::capture_underway is false.");
814     capture_underway = false;
815   }
816 
817   void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
818     // Q: Do we need to do anything special here, like clear long-lived
819     //    pointers created during the original capture (for example,
820     //    tensors intended as the graph's I/O surface) that might still
821     //    be resident in ptr_info?
822     // A: I don't think so.
823     //    Those allocations survived capture because the user held
824     //    explicit tensor references to them.
825     //    Those tensors' destructors will call freeAsync() on each pointer
826     //    when the user is done with them.
827     //    The freeAsync()s will probably incur
828     //    TORCH_WARN("Attempting uncaptured free of a captured allocation..."
829     //    but stale ptrs will not permanently leak into ptr_info.
830   }
831 
832   void* raw_alloc(size_t nbytes) override {
833     if (nbytes == 0) {
834       return nullptr;
835     }
836     c10::DeviceIndex device = 0;
837     C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
838     void* r = nullptr;
839     mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
840     return r;
841   }
842 
843   void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override {
844     if (nbytes == 0) {
845       return nullptr;
846     }
847     c10::DeviceIndex device = 0;
848     C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
849     void* r = nullptr;
850     mallocAsync(&r, device, nbytes, stream);
851     return r;
852   }
853   void raw_delete(void* ptr) override {
854     freeAsync(ptr);
855   }
856   void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access)
857       override {
858     // Double-checks allocator backend hasn't changed, which would definitely be
859     // an error. cudaMallocAsync pools are unaffected by
860     // cudaDeviceEnablePeerAccess. We need pool-specific enablement. See
861     // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-2/
862     c10::cuda::CUDAGuard device_guard(dev);
863     cudaMemPool_t mempool = nullptr;
864     C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, dev_to_access));
865     cudaMemAccessDesc desc = {};
866     desc.location.type = cudaMemLocationTypeDevice;
867     // NOLINTNEXTLINE(bugprone-signed-char-misuse)
868     desc.location.id = dev;
869     desc.flags = cudaMemAccessFlagsProtReadWrite;
870     C10_CUDA_CHECK(cudaMemPoolSetAccess(mempool, &desc, 1 /* numDescs */));
871   }
872   cudaError_t memcpyAsync(
873       void* dst,
874       int dstDevice,
875       const void* src,
876       int srcDevice,
877       size_t count,
878       cudaStream_t stream,
879       bool p2p_enabled) override {
880     if (p2p_enabled || dstDevice == srcDevice) {
881       return cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream);
882     } else {
883       return cudaMemcpyPeerAsync(dst, dstDevice, src, srcDevice, count, stream);
884     }
885   }
886   std::string name() override {
887     return "cudaMallocAsync";
888   }
889   void copy_data(void* dest, const void* src, std::size_t count) const final {
890     C10_CUDA_CHECK(
891         cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
892   }
893 };
894 
895 CudaMallocAsyncAllocator device_allocator;
896 
897 void local_raw_delete(void* ptr) {
898   freeAsync(ptr);
899 }
900 CUDAAllocator* allocator() {
901   return &device_allocator;
902 }
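// For reference: this backend is selected through the allocator configuration,
// e.g. by setting the PYTORCH_CUDA_ALLOC_CONF environment variable to
// "backend:cudaMallocAsync" before the first CUDA allocation; allocator()
// above is what that selection ultimately returns.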
903 
904 #else
905 CUDAAllocator* allocator() {
906   TORCH_CHECK(false, "Cannot use cudaMallocAsyncAllocator with cuda < 11.4.");
907   return nullptr;
908 }
909 
910 #endif
911 
912 } // namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync
913