#include <c10/cuda/CUDACachingAllocator.h>

#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Gauge.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/hash.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/static_tracepoint.h>

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
#endif

#include <c10/util/Exception.h>
#include <cuda_runtime_api.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <regex>
#include <set>
#include <utility>
#include <vector>

TORCH_SDT_DEFINE_SEMAPHORE(malloc)
TORCH_SDT_DEFINE_SEMAPHORE(free)

namespace c10 {

C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);

namespace cuda::CUDACachingAllocator {

using namespace c10::CachingDeviceAllocator;

// Included here as this is externally used in CUDAAllocatorConfig
const size_t kLargeBuffer =
    20971520; // "large" allocations may be packed in 20 MiB blocks

namespace Native {

//
// Yet another caching allocator for CUDA device allocations.
//
// - Allocations are associated with a stream. Once freed, blocks can be
//   re-allocated on the same stream, but not on any other stream.
// - The allocator attempts to find the smallest cached block that will fit the
//   requested size. If the block is larger than the requested size, it may be
//   split. If no block is found, the allocator will delegate to cudaMalloc.
// - If the cudaMalloc fails, the allocator will attempt to free one cached
//   block of sufficient size that is not split and retry the allocation.
//   If this also fails, the allocator will attempt to free all cached blocks
//   that are not split and retry the allocation.
// - Large (>1MB) and small allocations are stored in separate pools.
//   Small requests are packed into 2MB buffers. Large requests will use the
//   smallest available free block or allocate a new block using cudaMalloc.
// - To reduce fragmentation, requests between 1MB and 10MB will allocate and
//   split a 20MB block, if no free block of sufficient size is available.
// - To further reduce fragmentation, blocks >= max_split_size are not allowed
//   to be split. These oversize cached blocks will still satisfy requests
//   within 1MB of the oversize cached block size.
//
// With this allocator, allocations and frees should logically be considered
// "usages" of the memory segment associated with streams, just like kernel
// launches. The programmer must insert the proper synchronization if memory
// segments are used from multiple streams.
//
// The library provides a recordStream() function to help insert the correct
// synchronization when allocations are used on multiple streams. This will
// ensure that the block is not reused before each recorded stream completes
// work.
//
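// As an illustration only (not part of this file's logic), a minimal sketch of
// the multi-stream contract above, assuming the public allocate()/recordStream()
// entry points and a hypothetical kernel-launch helper:
//
//   c10::DataPtr ptr = c10::cuda::CUDACachingAllocator::get()->allocate(nbytes);
//   c10::cuda::CUDAStream side = c10::cuda::getStreamFromPool();
//   launch_my_kernel(side, ptr.get()); // hypothetical: uses ptr on `side`
//   // Without this call, the block could be handed out again as soon as `ptr`
//   // is freed, even if work queued on `side` has not finished yet.
//   c10::cuda::CUDACachingAllocator::recordStream(ptr, side);
//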

/**
 * Note [Interaction with CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Graph capture performs a dry run of a region of execution, freezing all CUDA
 * work (and virtual addresses used during that work) into a "graph." The graph
 * may be "replayed" like a single giant kernel, with greatly reduced CPU
 * overhead as well as modestly improved GPU performance.
 *
 * Because capture bakes in memory addresses, the memory used during capture
 * must be available for the graph to use during replay. DeviceCachingAllocator
 * assigns and frees memory eagerly and dynamically, so if we're not careful
 * about managing graphs' memory, at replay time those memory addresses could
 * be used by other tensors.
 *
 * To guarantee a graph's baked-in addresses are safe to reuse in replay,
 * DeviceAllocator satisfies allocations from a graph-private memory pool during
 * capture, and doesn't begin cudaFreeing those addresses until the graph is
 * destroyed.
 *
 * Within the private pool, allocations are freed and reassigned as usual during
 * capture. Memory regions will be used in a consistent order during replay. So
 * a private pool doesn't use memory more wastefully than the default pools
 * during capture, but it does reserve its high-water mark of used memory away
 * from the default pools as long as the capture(s) it served survive
 * (regardless of whether those captures are idle or replaying).
 *
 * CUDAGraph's requests for private pools are mediated by
 * DeviceAllocator::notifyCaptureBegin,
 *                  notifyCaptureAboutToEnd,
 *                  notifyCaptureEnded,
 *                  notifyCaptureDestroy.
 */
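
// A rough sketch of the call order implied by the note above (illustration
// only; the argument lists here are placeholders, the exact signatures live in
// the allocator interface):
//
//   notifyCaptureBegin(device, graph_id, mempool_id); // before capture starts
//   ... allocations made during capture come from the graph-private pool ...
//   notifyCaptureAboutToEnd(device, graph_id);        // just before capture ends
//   notifyCaptureEnded(device, graph_id);
//   ... the graph may now be replayed any number of times ...
//   notifyCaptureDestroy(device, mempool_id);         // private pool may be freed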

constexpr size_t kMinBlockSize =
    512; // all sizes are rounded to at least 512 bytes
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
constexpr size_t kSmallBuffer =
    2097152; // "small" allocations are packed in 2 MiB blocks
constexpr size_t kMinLargeAlloc =
    10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
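
// Worked example of how these thresholds interact (illustration only, derived
// from the constants and the overview comment above):
//   - a 700-byte request rounds up to 1024 bytes (next multiple of
//     kMinBlockSize) and is served from a 2 MiB (kSmallBuffer) segment;
//   - a 5 MiB request (> kSmallSize, <= kMinLargeAlloc) is served by splitting
//     a 20 MiB (kLargeBuffer) segment;
//   - a 21 MiB request (> kMinLargeAlloc) rounds up to 22 MiB (next multiple
//     of kRoundLarge) and gets its own segment from cudaMalloc.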

char SHAREABLE_HANDLE_VERSION = 1;
enum ShareableHandleType : char {
  SHAREABLE_CUDA_MALLOC = 'c',
  SHAREABLE_CUDA_EXPANDABLE_SEGMENT = 'e'
};

namespace {

using stream_set = ska::flat_hash_set<cuda::CUDAStream>;

void decrease_stat_array(
    StatArray& stat_array,
    size_t amount,
    const StatTypes& stat_types) {
  for_each_selected_stat_type(
      stat_types, [&stat_array, amount](size_t stat_type) {
        stat_array[stat_type].decrease(amount);
      });
}

struct Block;
struct PrivatePool;
typedef bool (*Comparison)(const Block*, const Block*);
static bool BlockComparatorSize(const Block* a, const Block* b);
static bool BlockComparatorAddress(const Block* a, const Block* b);

struct BlockPool {
  BlockPool(bool small, PrivatePool* private_pool = nullptr)
      : blocks(BlockComparatorSize),
        unmapped(BlockComparatorAddress),
        is_small(small),
        owner_PrivatePool(private_pool) {}

  // Do not insert a Block into blocks directly; use insert_into_blocks()
  // instead.
  std::set<Block*, Comparison> blocks;
  std::set<Block*, Comparison> unmapped;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const bool is_small;
  PrivatePool* owner_PrivatePool;
  int64_t get_free_blocks_call_count{0};

  // Add a Block into the blocks set, updating the gc counter.
  std::pair<std::set<Block*, Comparison>::iterator, bool> insert_into_blocks(
      Block* block);
};

struct ExpandableSegment;

struct Block {
  c10::DeviceIndex device; // gpu
  cudaStream_t stream; // allocation stream
  stream_set stream_uses; // streams on which the block was used
  size_t size; // block size in bytes
  size_t requested_size; // memory originally requested
  BlockPool* pool{nullptr}; // owning memory pool
  void* ptr{nullptr}; // memory address
  bool allocated{false}; // in-use flag
  bool mapped{true}; // is the virtual address range this Block references
                     // backed by physical pages. Always true when
                     // expandable_segment_ is null. When false,
                     // this Block will be aligned to the segment size
                     // of its expandable_segment_.
  Block* prev{nullptr}; // prev block if split from a larger allocation
  Block* next{nullptr}; // next block if split from a larger allocation
  int event_count{0}; // number of outstanding CUDA events
  int64_t gc_count_base{0}; // get_free_blocks_call_count when Block is inserted
  std::shared_ptr<GatheredContext> context_when_allocated;
  // only set for the first block in the segment (when prev == null)
  // this records the frame information when cudaMalloc was called
  // whereas context_when_allocated records the last time we handed this
  // memory out from our cache.
  std::shared_ptr<GatheredContext> context_when_segment_allocated;

  ExpandableSegment* expandable_segment_{nullptr};

  Block(
      c10::DeviceIndex device,
      cudaStream_t stream,
      size_t size,
      BlockPool* pool,
      void* ptr)
      : device(device),
        stream(stream),
        stream_uses(),
        size(size),
        requested_size(0),
        pool(pool),
        ptr(ptr) {}

  // constructor for search key
  Block(c10::DeviceIndex device, cudaStream_t stream, size_t size)
      : device(device),
        stream(stream),
        stream_uses(),
        size(size),
        requested_size(0) {}

  size_t gc_count() {
    TORCH_INTERNAL_ASSERT(pool);
    return static_cast<int>(pool->get_free_blocks_call_count - gc_count_base);
  }

  bool is_split() const {
    return (prev != nullptr) || (next != nullptr);
  }
  void splice(Block* before, Block* after) {
    if (before) {
      TORCH_INTERNAL_ASSERT(before->next == after);
      before->next = this;
    }
    prev = before;
    if (after) {
      TORCH_INTERNAL_ASSERT(after->prev == before);
      after->prev = this;
    }
    next = after;
  }
};

std::pair<std::set<Block*, Comparison>::iterator, bool> BlockPool::
    insert_into_blocks(Block* block) {
  block->gc_count_base = get_free_blocks_call_count;
  return blocks.insert(block);
}

struct SegmentRange {
  char* ptr;
  size_t size;
  SegmentRange(void* p, size_t s) : ptr(static_cast<char*>(p)), size(s) {}
};

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)

/*
Note [Expandable Segments]

Rationale

For large (>2MB) allocations, the allocator calls cudaMalloc to get allocations
that are the same size as what the user requests. In the future, parts of these
allocations can be reused for other requests if they are free. This works well
when the program makes many requests of exactly the same size or of sizes that
are even multiples of that size. Many deep learning models follow this
behavior. However, one common exception is when the batch size changes slightly
from one iteration to the next, e.g. in batched inference. When the program
runs initially with batch size N, it will make allocations appropriate for that
size. If in the future, it runs at size N - 1, the existing allocations will
still be big enough. However, if it runs at size N + 1, then it will have to
make new allocations that are slightly larger. Not all the tensors are the same
size. Some might be (N + 1)*A and others (N + 1)*A*B where A and B are some
non-batch dimensions in the model. Because the allocator reuses existing
allocations when they are big enough, some number of (N + 1)*A allocations will
actually fit in the already existing N*B*A segments, though not perfectly. As
the model runs it will partially fill up all of these segments leaving unusable
free slices of memory at the end of these segments. The allocator at some point
will need to cudaMalloc a new (N + 1)*A*B segment. If there is not enough
memory, there is now no way to recover the slices of memory that are free at
the end of existing segments. With models 50+ layers deep, this pattern might
repeat 50+ times creating many slivers.

Approach

Expandable segments allows the allocator to create a segment initially and then
expand its size later when more memory is needed. Instead of making one segment
per allocation, it tries to make one segment (per stream) that grows as
necessary. Now when the N + 1 case runs, the allocations will tile nicely into
the one large segment until it fills up. Then more memory is requested and
appended to the end of the segment. This process does not create as many
slivers of unusable memory, so it is more likely to succeed at finding this
memory.

Implementation

The expandable_segments:True option is used to enable/disable this behavior. We
use CUDA's low-level memory APIs, which are similar to mmap, to extend the
memory segments. These APIs separate the allocation of physical memory
(cuMemCreate) from the allocation of virtual address space (cuMemAddressReserve)
and the association between them (cuMemMap/cuMemSetAccess).

When we allocate a new segment, we allocate enough address space to map
basically the entire physical memory of the GPU (there is 256TiB of address
space), but we only map enough physical memory to handle the current amount of
memory needed by the program. As more is requested, we add more physical memory
to the segment. This can work at the granularity of GPU pages which are 2MiB
currently.

If we end up out of memory, we can unmap all the memory in our segment
corresponding to empty physical pages, and return it to CUDA for use at another
address in the segment or in a segment for a different stream.

A current limitation of CUDA's API is that physical memory
(CUmemGenericAllocationHandle) cannot be split up after it is mapped even if the
handle holds multiple GPU pages. The cost to map/unmap memory is proportional to
the number of physical memory chunks that were allocated (mapping 10 separately
allocated 2MiB pages takes 10x time compared to mapping one 20MiB physical
allocation of 10 pages). Changing memory mappings also appears to involve at
least some synchronous actions with the GPU and so should be considered an
expensive operation. To limit overhead, we use 2MiB pages for our small pool and
20MiB pages for our large pool. Initially allocation using expandable_blocks
will be slower than cudaMalloc, though still in the milliseconds range for
mapping the entire memory.

When mapping new memory to expand the segment, we look for the lowest address at
which we can fit a new allocation by adding new pages. Normally this will be at
the end of the block. But if we have previously unmapped blocks earlier in the
segment during an OOM, it will first try to fill in those gaps to keep the
segment as a single block. By allocating at the lowest address we encourage
the split up parts of the block to merge into a single block again, reducing
fragmentation potential.

Allocation of blocks in the segment uses the same best-fit heuristics as the
rest of the allocator.

Expandable blocks can be enabled/disabled throughout the run of a program. When
disabled, the allocator will not put new allocations in an expandable block.

Limitations

* Slightly slower initial memory allocation speed.
* IPC of cuda tensors (e.g. for multiprocess dataloaders) is not supported.
However, it is possible to temporarily disable (expandable_segments:False) the
behavior for allocator tensors that need to be used cross-process.
* CUDA runtime APIs related to sharing memory across processes
(cudaDeviceEnablePeerAccess) do not work for memory allocated with cuMemMap.
Instead these mappings have to be done manually. The allocator now has an
`enablePeerAccess` method to do this.
*/
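
// Condensed sketch of the driver-API sequence the note describes (illustration
// only; ExpandableSegment below is the real implementation, with error
// handling and bookkeeping; `va_size`, `offset`, and `device` are assumed):
//
//   CUdeviceptr base = 0;
//   // reserve a large virtual address range up front; no physical memory yet
//   cuMemAddressReserve(&base, va_size, 0, 0, 0);
//   // back one segment_size-sized slice of it with physical memory
//   CUmemGenericAllocationHandle handle = 0;
//   CUmemAllocationProp prop = {};
//   prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
//   prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
//   prop.location.id = device;
//   cuMemCreate(&handle, segment_size, &prop, 0);
//   cuMemMap(base + offset, segment_size, 0, handle, 0);
//   CUmemAccessDesc desc = {};
//   desc.location = prop.location;
//   desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
//   cuMemSetAccess(base + offset, segment_size, &desc, 1);
//   // ... later, to shrink the segment again:
//   cuMemUnmap(base + offset, segment_size);
//   cuMemRelease(handle);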
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker struct ExpandableSegment {
ExpandableSegmentc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment357*da0073e9SAndroid Build Coastguard Worker ExpandableSegment(
358*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
359*da0073e9SAndroid Build Coastguard Worker std::optional<cudaStream_t> stream,
360*da0073e9SAndroid Build Coastguard Worker size_t address_space_size,
361*da0073e9SAndroid Build Coastguard Worker size_t segment_size,
362*da0073e9SAndroid Build Coastguard Worker std::vector<c10::DeviceIndex> peers)
363*da0073e9SAndroid Build Coastguard Worker : device_(device),
364*da0073e9SAndroid Build Coastguard Worker stream_(stream),
365*da0073e9SAndroid Build Coastguard Worker // 2MB for small pool, 20MB for large pool
366*da0073e9SAndroid Build Coastguard Worker segment_size_(segment_size),
367*da0073e9SAndroid Build Coastguard Worker max_handles_(numSegments(address_space_size)),
368*da0073e9SAndroid Build Coastguard Worker peers_(std::move(peers)) {
369*da0073e9SAndroid Build Coastguard Worker cudaDeviceProp prop{};
370*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_));
371*da0073e9SAndroid Build Coastguard Worker // we allocate enough address space for 1 1/8 the total memory on the GPU.
372*da0073e9SAndroid Build Coastguard Worker // This allows for some cases where we have to unmap pages earlier in the
373*da0073e9SAndroid Build Coastguard Worker // segment to put them at the end.
374*da0073e9SAndroid Build Coastguard Worker max_handles_ = numSegments(prop.totalGlobalMem + prop.totalGlobalMem / 8);
375*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemAddressReserve_(
376*da0073e9SAndroid Build Coastguard Worker &ptr_, segment_size_ * max_handles_, 0ULL, 0, 0ULL));
377*da0073e9SAndroid Build Coastguard Worker }
378*da0073e9SAndroid Build Coastguard Worker // begin must be aligned to segment_size_.
379*da0073e9SAndroid Build Coastguard Worker // returns the actual range mapped, which may be
380*da0073e9SAndroid Build Coastguard Worker // greater than requested if size is not aligned to segment_size_.
381*da0073e9SAndroid Build Coastguard Worker // return size of 0 indicates OOM
mapc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment382*da0073e9SAndroid Build Coastguard Worker SegmentRange map(SegmentRange range) {
383*da0073e9SAndroid Build Coastguard Worker auto begin = segmentLeft(range.ptr);
384*da0073e9SAndroid Build Coastguard Worker auto end = segmentRight(range.ptr + range.size);
385*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(ptr() + begin * segment_size_ == range.ptr);
386*da0073e9SAndroid Build Coastguard Worker if (begin == end) {
387*da0073e9SAndroid Build Coastguard Worker return rangeFromHandles(begin, end);
388*da0073e9SAndroid Build Coastguard Worker }
389*da0073e9SAndroid Build Coastguard Worker while (end > handles_.size()) {
390*da0073e9SAndroid Build Coastguard Worker handles_.emplace_back(std::nullopt);
391*da0073e9SAndroid Build Coastguard Worker }
392*da0073e9SAndroid Build Coastguard Worker for (auto i : c10::irange(begin, end)) {
393*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(!handles_.at(i));
394*da0073e9SAndroid Build Coastguard Worker CUmemGenericAllocationHandle handle = 0;
395*da0073e9SAndroid Build Coastguard Worker CUmemAllocationProp prop = {};
396*da0073e9SAndroid Build Coastguard Worker prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
397*da0073e9SAndroid Build Coastguard Worker #ifndef FBCODE_CAFFE2
398*da0073e9SAndroid Build Coastguard Worker prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
399*da0073e9SAndroid Build Coastguard Worker #endif
400*da0073e9SAndroid Build Coastguard Worker prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
401*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(bugprone-signed-char-misuse)
402*da0073e9SAndroid Build Coastguard Worker prop.location.id = static_cast<int>(device_);
403*da0073e9SAndroid Build Coastguard Worker auto status =
404*da0073e9SAndroid Build Coastguard Worker DriverAPI::get()->cuMemCreate_(&handle, segment_size_, &prop, 0);
405*da0073e9SAndroid Build Coastguard Worker if (status == CUDA_ERROR_OUT_OF_MEMORY) {
406*da0073e9SAndroid Build Coastguard Worker for (auto j : c10::irange(begin, i)) {
407*da0073e9SAndroid Build Coastguard Worker auto h = handles_.at(j).value();
408*da0073e9SAndroid Build Coastguard Worker handles_.at(j) = std::nullopt;
409*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle));
410*da0073e9SAndroid Build Coastguard Worker }
411*da0073e9SAndroid Build Coastguard Worker trimHandles();
412*da0073e9SAndroid Build Coastguard Worker return rangeFromHandles(begin, begin);
413*da0073e9SAndroid Build Coastguard Worker }
414*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(status);
415*da0073e9SAndroid Build Coastguard Worker handles_.at(i) = Handle{handle, std::nullopt};
416*da0073e9SAndroid Build Coastguard Worker }
417*da0073e9SAndroid Build Coastguard Worker mapAndSetAccess(begin, end);
418*da0073e9SAndroid Build Coastguard Worker return rangeFromHandles(begin, end);
419*da0073e9SAndroid Build Coastguard Worker }
420*da0073e9SAndroid Build Coastguard Worker
421*da0073e9SAndroid Build Coastguard Worker // unmaps all the completely empty segment_size_ segments between
422*da0073e9SAndroid Build Coastguard Worker // [begin, begin + size), returns the offset where the range begin,
423*da0073e9SAndroid Build Coastguard Worker // and the actual size unmapped (multiple of segment_size_)
unmapc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment424*da0073e9SAndroid Build Coastguard Worker SegmentRange unmap(SegmentRange range) {
425*da0073e9SAndroid Build Coastguard Worker auto begin = segmentRight(range.ptr);
426*da0073e9SAndroid Build Coastguard Worker auto end = segmentLeft(range.ptr + range.size);
427*da0073e9SAndroid Build Coastguard Worker if (begin >= end) {
428*da0073e9SAndroid Build Coastguard Worker return SegmentRange{range.ptr, 0};
429*da0073e9SAndroid Build Coastguard Worker }
430*da0073e9SAndroid Build Coastguard Worker unmapHandles(begin, end);
431*da0073e9SAndroid Build Coastguard Worker return rangeFromHandles(begin, end);
432*da0073e9SAndroid Build Coastguard Worker }
433*da0073e9SAndroid Build Coastguard Worker
434*da0073e9SAndroid Build Coastguard Worker // Setup IPC sharing for range.
435*da0073e9SAndroid Build Coastguard Worker // Returns the (larger) range that was actually shared.
436*da0073e9SAndroid Build Coastguard Worker // Serializes data to std::ostream that can be passed to the
437*da0073e9SAndroid Build Coastguard Worker // other process, and then restored as an exapandable segment
438*da0073e9SAndroid Build Coastguard Worker // via ExpandableSegment::fromShared(istream);
sharec10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment439*da0073e9SAndroid Build Coastguard Worker SegmentRange share(SegmentRange range, std::ostream& buf) {
440*da0073e9SAndroid Build Coastguard Worker auto begin = segmentLeft(range.ptr);
441*da0073e9SAndroid Build Coastguard Worker auto end = segmentRight(range.ptr + range.size);
442*da0073e9SAndroid Build Coastguard Worker ShareHeader header{getpid(), segment_size_, end - begin};
443*da0073e9SAndroid Build Coastguard Worker buf.write((const char*)&header, sizeof(ShareHeader));
444*da0073e9SAndroid Build Coastguard Worker for (auto i : c10::irange(begin, end)) {
445*da0073e9SAndroid Build Coastguard Worker auto& handle = handles_.at(i).value();
446*da0073e9SAndroid Build Coastguard Worker if (!handle.fd) {
447*da0073e9SAndroid Build Coastguard Worker int fd = 0;
448*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemExportToShareableHandle_(
449*da0073e9SAndroid Build Coastguard Worker &fd, handle.handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
450*da0073e9SAndroid Build Coastguard Worker handle.fd = fd;
451*da0073e9SAndroid Build Coastguard Worker }
452*da0073e9SAndroid Build Coastguard Worker int fd = *handle.fd;
453*da0073e9SAndroid Build Coastguard Worker buf.write((const char*)&fd, sizeof(int));
454*da0073e9SAndroid Build Coastguard Worker }
455*da0073e9SAndroid Build Coastguard Worker return rangeFromHandles(begin, end);
456*da0073e9SAndroid Build Coastguard Worker }
457*da0073e9SAndroid Build Coastguard Worker
fromSharedc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment458*da0073e9SAndroid Build Coastguard Worker static std::unique_ptr<ExpandableSegment> fromShared(
459*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
460*da0073e9SAndroid Build Coastguard Worker std::vector<c10::DeviceIndex> peers,
461*da0073e9SAndroid Build Coastguard Worker std::istream& buf) {
462*da0073e9SAndroid Build Coastguard Worker ShareHeader header{};
463*da0073e9SAndroid Build Coastguard Worker buf.read((char*)&header, sizeof(ShareHeader));
464*da0073e9SAndroid Build Coastguard Worker auto segment = std::make_unique<ExpandableSegment>(
465*da0073e9SAndroid Build Coastguard Worker device,
466*da0073e9SAndroid Build Coastguard Worker std::nullopt,
467*da0073e9SAndroid Build Coastguard Worker header.num_handles * header.segment_size,
468*da0073e9SAndroid Build Coastguard Worker header.segment_size,
469*da0073e9SAndroid Build Coastguard Worker std::move(peers));
470*da0073e9SAndroid Build Coastguard Worker // older build setups (e.g. multiwheels) do not have this syscall, added 2020
471*da0073e9SAndroid Build Coastguard Worker // but the kernel on the system might still support it.
472*da0073e9SAndroid Build Coastguard Worker #ifndef SYS_pidfd_open
473*da0073e9SAndroid Build Coastguard Worker #define SYS_pidfd_open 434
474*da0073e9SAndroid Build Coastguard Worker #endif
475*da0073e9SAndroid Build Coastguard Worker #ifndef SYS_pidfd_getfd
476*da0073e9SAndroid Build Coastguard Worker #define SYS_pidfd_getfd 438
477*da0073e9SAndroid Build Coastguard Worker #endif
478*da0073e9SAndroid Build Coastguard Worker auto pidfd = syscall(SYS_pidfd_open, header.pid, 0);
479*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
480*da0073e9SAndroid Build Coastguard Worker pidfd != -1 || errno != ENOSYS,
481*da0073e9SAndroid Build Coastguard Worker "The kernel on this machine does not support the pidfd_open syscall needed to use IPC for CUDA tensors when expandable_segments:True is set. "
482*da0073e9SAndroid Build Coastguard Worker "Consider using expandable_segments:False via torch.cuda.memory._set_allocator_settings('expandable_segments:False') for this allocation.");
483*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(pidfd != -1, "pidfd_open:", std::strerror(errno));
484*da0073e9SAndroid Build Coastguard Worker for (auto i : c10::irange(header.num_handles)) {
485*da0073e9SAndroid Build Coastguard Worker (void)i;
486*da0073e9SAndroid Build Coastguard Worker int fd = 0;
487*da0073e9SAndroid Build Coastguard Worker buf.read((char*)&fd, sizeof(int));
488*da0073e9SAndroid Build Coastguard Worker auto myfd = syscall(SYS_pidfd_getfd, pidfd, fd, 0);
489*da0073e9SAndroid Build Coastguard Worker if (myfd == -1) {
490*da0073e9SAndroid Build Coastguard Worker auto err = errno;
491*da0073e9SAndroid Build Coastguard Worker close((int)pidfd);
492*da0073e9SAndroid Build Coastguard Worker for (auto& h : segment->handles_) {
493*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(
494*da0073e9SAndroid Build Coastguard Worker DriverAPI::get()->cuMemRelease_(h.value().handle));
495*da0073e9SAndroid Build Coastguard Worker h = std::nullopt;
496*da0073e9SAndroid Build Coastguard Worker }
497*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
498*da0073e9SAndroid Build Coastguard Worker err != ENOSYS,
499*da0073e9SAndroid Build Coastguard Worker "The kernel on this machine does not support the pidfd_getfd syscall needed to use IPC for CUDA tensors when expandable_segments:True is set. "
500*da0073e9SAndroid Build Coastguard Worker "Consider using expandable_segments:False via torch.cuda.memory._set_allocator_settings('expandable_segments:False') for this allocation.");
501*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "pidfd_getfd: ", std::strerror(err));
502*da0073e9SAndroid Build Coastguard Worker }
503*da0073e9SAndroid Build Coastguard Worker CUmemGenericAllocationHandle handle = 0;
504*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_(
505*da0073e9SAndroid Build Coastguard Worker &handle,
506*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(performance-no-int-to-ptr)
507*da0073e9SAndroid Build Coastguard Worker (void*)(uintptr_t)myfd,
508*da0073e9SAndroid Build Coastguard Worker CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
509*da0073e9SAndroid Build Coastguard Worker close((int)myfd);
510*da0073e9SAndroid Build Coastguard Worker segment->handles_.emplace_back(Handle{handle, std::nullopt});
511*da0073e9SAndroid Build Coastguard Worker }
512*da0073e9SAndroid Build Coastguard Worker close((int)pidfd);
513*da0073e9SAndroid Build Coastguard Worker segment->mapAndSetAccess(0, header.num_handles);
514*da0073e9SAndroid Build Coastguard Worker return segment;
515*da0073e9SAndroid Build Coastguard Worker }
516*da0073e9SAndroid Build Coastguard Worker
ptrc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment517*da0073e9SAndroid Build Coastguard Worker char* ptr() const {
518*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(performance-no-int-to-ptr)
519*da0073e9SAndroid Build Coastguard Worker return reinterpret_cast<char*>(ptr_);
520*da0073e9SAndroid Build Coastguard Worker }
521*da0073e9SAndroid Build Coastguard Worker
sizec10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment522*da0073e9SAndroid Build Coastguard Worker size_t size() const {
523*da0073e9SAndroid Build Coastguard Worker return max_handles_ * segment_size_;
524*da0073e9SAndroid Build Coastguard Worker }
525*da0073e9SAndroid Build Coastguard Worker
addPeerc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment526*da0073e9SAndroid Build Coastguard Worker void addPeer(c10::DeviceIndex device) {
527*da0073e9SAndroid Build Coastguard Worker peers_.push_back(device);
528*da0073e9SAndroid Build Coastguard Worker forEachAllocatedRange(
529*da0073e9SAndroid Build Coastguard Worker [&](size_t begin, size_t end) { setAccess(device, begin, end); });
530*da0073e9SAndroid Build Coastguard Worker }
531*da0073e9SAndroid Build Coastguard Worker
~ExpandableSegmentc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment532*da0073e9SAndroid Build Coastguard Worker ~ExpandableSegment() {
533*da0073e9SAndroid Build Coastguard Worker forEachAllocatedRange(
534*da0073e9SAndroid Build Coastguard Worker [&](size_t begin, size_t end) { unmapHandles(begin, end); });
535*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemAddressFree_(
536*da0073e9SAndroid Build Coastguard Worker ptr_, segment_size_ * max_handles_));
537*da0073e9SAndroid Build Coastguard Worker }
538*da0073e9SAndroid Build Coastguard Worker
539*da0073e9SAndroid Build Coastguard Worker private:
setAccessc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment540*da0073e9SAndroid Build Coastguard Worker void setAccess(c10::DeviceIndex device, size_t begin, size_t end) {
541*da0073e9SAndroid Build Coastguard Worker CUmemAccessDesc desc;
542*da0073e9SAndroid Build Coastguard Worker desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
543*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(bugprone-signed-char-misuse)
544*da0073e9SAndroid Build Coastguard Worker desc.location.id = static_cast<int>(device);
545*da0073e9SAndroid Build Coastguard Worker desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
546*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemSetAccess_(
547*da0073e9SAndroid Build Coastguard Worker ptr_ + begin * segment_size_, (end - begin) * segment_size_, &desc, 1));
548*da0073e9SAndroid Build Coastguard Worker }
549*da0073e9SAndroid Build Coastguard Worker
mapAndSetAccessc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment550*da0073e9SAndroid Build Coastguard Worker void mapAndSetAccess(size_t begin, size_t end) {
551*da0073e9SAndroid Build Coastguard Worker for (auto i : c10::irange(begin, end)) {
552*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemMap_(
553*da0073e9SAndroid Build Coastguard Worker ptr_ + i * segment_size_,
554*da0073e9SAndroid Build Coastguard Worker segment_size_,
555*da0073e9SAndroid Build Coastguard Worker 0,
556*da0073e9SAndroid Build Coastguard Worker handles_.at(i).value().handle,
557*da0073e9SAndroid Build Coastguard Worker 0ULL));
558*da0073e9SAndroid Build Coastguard Worker }
559*da0073e9SAndroid Build Coastguard Worker setAccess(device_, begin, end);
560*da0073e9SAndroid Build Coastguard Worker for (auto p : peers_) {
561*da0073e9SAndroid Build Coastguard Worker setAccess(p, begin, end);
562*da0073e9SAndroid Build Coastguard Worker }
563*da0073e9SAndroid Build Coastguard Worker }
564*da0073e9SAndroid Build Coastguard Worker
unmapHandlesc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment565*da0073e9SAndroid Build Coastguard Worker void unmapHandles(size_t begin, size_t end) {
566*da0073e9SAndroid Build Coastguard Worker // note: unlike cudaFree, MemUnmap and MemRelease do
567*da0073e9SAndroid Build Coastguard Worker // not appear to synchronize in all cases, so we have to wait for the
568*da0073e9SAndroid Build Coastguard Worker // stream to finish before this memory is truly free.
569*da0073e9SAndroid Build Coastguard Worker
570*da0073e9SAndroid Build Coastguard Worker // cannot call c10::cuda::stream_synchronize because
571*da0073e9SAndroid Build Coastguard Worker // it might grab the GIL which can lead to a deadlock
572*da0073e9SAndroid Build Coastguard Worker // Locking order must be GIL -> Allocator Lock
573*da0073e9SAndroid Build Coastguard Worker if (stream_) {
574*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaStreamSynchronize(*stream_));
575*da0073e9SAndroid Build Coastguard Worker } else {
576*da0073e9SAndroid Build Coastguard Worker cuda::CUDAGuard device_guard(device_);
577*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaDeviceSynchronize());
578*da0073e9SAndroid Build Coastguard Worker }
579*da0073e9SAndroid Build Coastguard Worker for (auto i : c10::irange(begin, end)) {
580*da0073e9SAndroid Build Coastguard Worker Handle h = handles_.at(i).value();
581*da0073e9SAndroid Build Coastguard Worker handles_.at(i) = std::nullopt;
582*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemUnmap_(
583*da0073e9SAndroid Build Coastguard Worker ptr_ + segment_size_ * i, segment_size_));
584*da0073e9SAndroid Build Coastguard Worker if (h.fd) {
585*da0073e9SAndroid Build Coastguard Worker close(*h.fd);
586*da0073e9SAndroid Build Coastguard Worker }
587*da0073e9SAndroid Build Coastguard Worker C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle));
588*da0073e9SAndroid Build Coastguard Worker }
589*da0073e9SAndroid Build Coastguard Worker trimHandles();
590*da0073e9SAndroid Build Coastguard Worker }
trimHandlesc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment591*da0073e9SAndroid Build Coastguard Worker void trimHandles() {
592*da0073e9SAndroid Build Coastguard Worker while (!handles_.empty() && !handles_.back()) {
593*da0073e9SAndroid Build Coastguard Worker handles_.pop_back();
594*da0073e9SAndroid Build Coastguard Worker }
595*da0073e9SAndroid Build Coastguard Worker }
forEachAllocatedRangec10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment596*da0073e9SAndroid Build Coastguard Worker void forEachAllocatedRange(const std::function<void(size_t, size_t)>& fn) {
597*da0073e9SAndroid Build Coastguard Worker size_t start = 0;
598*da0073e9SAndroid Build Coastguard Worker for (auto i : c10::irange(handles_.size())) {
599*da0073e9SAndroid Build Coastguard Worker if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
600*da0073e9SAndroid Build Coastguard Worker start = i;
601*da0073e9SAndroid Build Coastguard Worker }
602*da0073e9SAndroid Build Coastguard Worker if (handles_.at(i) && (i + 1 == handles_.size() || !handles_.at(i + 1))) {
603*da0073e9SAndroid Build Coastguard Worker fn(start, i + 1);
604*da0073e9SAndroid Build Coastguard Worker }
605*da0073e9SAndroid Build Coastguard Worker }
606*da0073e9SAndroid Build Coastguard Worker }
numSegmentsc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment607*da0073e9SAndroid Build Coastguard Worker size_t numSegments(size_t size) {
608*da0073e9SAndroid Build Coastguard Worker return (size + segment_size_ - 1) / segment_size_;
609*da0073e9SAndroid Build Coastguard Worker }
segmentLeftc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment610*da0073e9SAndroid Build Coastguard Worker size_t segmentLeft(char* p) {
611*da0073e9SAndroid Build Coastguard Worker auto size = p - ptr();
612*da0073e9SAndroid Build Coastguard Worker return size / segment_size_;
613*da0073e9SAndroid Build Coastguard Worker }
segmentRightc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment614*da0073e9SAndroid Build Coastguard Worker size_t segmentRight(char* p) {
615*da0073e9SAndroid Build Coastguard Worker auto size = p - ptr();
616*da0073e9SAndroid Build Coastguard Worker return numSegments(size);
617*da0073e9SAndroid Build Coastguard Worker }
rangeFromHandlesc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment618*da0073e9SAndroid Build Coastguard Worker SegmentRange rangeFromHandles(size_t begin, size_t end) {
619*da0073e9SAndroid Build Coastguard Worker return SegmentRange(
620*da0073e9SAndroid Build Coastguard Worker ptr() + segment_size_ * begin, segment_size_ * (end - begin));
621*da0073e9SAndroid Build Coastguard Worker }
622*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device_;
623*da0073e9SAndroid Build Coastguard Worker std::optional<cudaStream_t> stream_;
624*da0073e9SAndroid Build Coastguard Worker CUdeviceptr ptr_{};
625*da0073e9SAndroid Build Coastguard Worker size_t segment_size_;
626*da0073e9SAndroid Build Coastguard Worker size_t max_handles_;
627*da0073e9SAndroid Build Coastguard Worker struct Handle {
628*da0073e9SAndroid Build Coastguard Worker CUmemGenericAllocationHandle handle;
629*da0073e9SAndroid Build Coastguard Worker std::optional<int> fd;
630*da0073e9SAndroid Build Coastguard Worker };
631*da0073e9SAndroid Build Coastguard Worker struct ShareHeader {
632*da0073e9SAndroid Build Coastguard Worker pid_t pid;
633*da0073e9SAndroid Build Coastguard Worker size_t segment_size;
634*da0073e9SAndroid Build Coastguard Worker size_t num_handles;
635*da0073e9SAndroid Build Coastguard Worker };
636*da0073e9SAndroid Build Coastguard Worker std::vector<std::optional<Handle>> handles_;
637*da0073e9SAndroid Build Coastguard Worker // devices on which this memory should be mapped in addition
638*da0073e9SAndroid Build Coastguard Worker // to the device where the physical memory lives (device_).
639*da0073e9SAndroid Build Coastguard Worker std::vector<c10::DeviceIndex> peers_;
640*da0073e9SAndroid Build Coastguard Worker };
641*da0073e9SAndroid Build Coastguard Worker #else
642*da0073e9SAndroid Build Coastguard Worker struct ExpandableSegment {
ExpandableSegmentc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment643*da0073e9SAndroid Build Coastguard Worker ExpandableSegment(
644*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
645*da0073e9SAndroid Build Coastguard Worker std::optional<cudaStream_t> stream,
646*da0073e9SAndroid Build Coastguard Worker size_t address_space_size,
647*da0073e9SAndroid Build Coastguard Worker size_t segment_size,
648*da0073e9SAndroid Build Coastguard Worker std::vector<c10::DeviceIndex> peers) {
649*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(false, "expandable segment not supported");
650*da0073e9SAndroid Build Coastguard Worker }
mapc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment651*da0073e9SAndroid Build Coastguard Worker SegmentRange map(SegmentRange range) {
652*da0073e9SAndroid Build Coastguard Worker return SegmentRange(nullptr, 0);
653*da0073e9SAndroid Build Coastguard Worker }
unmapc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment654*da0073e9SAndroid Build Coastguard Worker SegmentRange unmap(SegmentRange range) {
655*da0073e9SAndroid Build Coastguard Worker return SegmentRange(nullptr, 0);
656*da0073e9SAndroid Build Coastguard Worker }
sharec10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment657*da0073e9SAndroid Build Coastguard Worker SegmentRange share(SegmentRange range, std::ostream& ss) {
658*da0073e9SAndroid Build Coastguard Worker return SegmentRange(nullptr, 0);
659*da0073e9SAndroid Build Coastguard Worker }
fromSharedc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment660*da0073e9SAndroid Build Coastguard Worker static std::unique_ptr<ExpandableSegment> fromShared(
661*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
662*da0073e9SAndroid Build Coastguard Worker std::vector<c10::DeviceIndex> peers,
663*da0073e9SAndroid Build Coastguard Worker std::istream& buf) {
664*da0073e9SAndroid Build Coastguard Worker return {};
665*da0073e9SAndroid Build Coastguard Worker }
ptrc10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment666*da0073e9SAndroid Build Coastguard Worker char* ptr() const {
667*da0073e9SAndroid Build Coastguard Worker return nullptr;
668*da0073e9SAndroid Build Coastguard Worker }
sizec10::cuda::CUDACachingAllocator::Native::__anon8687b97a0111::ExpandableSegment669*da0073e9SAndroid Build Coastguard Worker size_t size() const {
670*da0073e9SAndroid Build Coastguard Worker return 0;
671*da0073e9SAndroid Build Coastguard Worker }
672*da0073e9SAndroid Build Coastguard Worker   void addPeer(c10::DeviceIndex device) {}
673*da0073e9SAndroid Build Coastguard Worker };
674*da0073e9SAndroid Build Coastguard Worker #endif
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker // BlockState, BlockPoolState, and PrivatePoolState contain the information
677*da0073e9SAndroid Build Coastguard Worker // needed to reconstruct a private pool to a previous state. See note
678*da0073e9SAndroid Build Coastguard Worker // [Checkpointing PrivatePoolState]
679*da0073e9SAndroid Build Coastguard Worker struct BlockState {
680*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device = 0;
681*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream = nullptr;
682*da0073e9SAndroid Build Coastguard Worker stream_set stream_uses = {};
683*da0073e9SAndroid Build Coastguard Worker size_t size = 0;
684*da0073e9SAndroid Build Coastguard Worker void* ptr = nullptr;
685*da0073e9SAndroid Build Coastguard Worker bool allocated = false;
686*da0073e9SAndroid Build Coastguard Worker int64_t gc_count_base = 0;
687*da0073e9SAndroid Build Coastguard Worker   // maintain the invariant that event_count == 0;
688*da0073e9SAndroid Build Coastguard Worker   // history is left alone when checkpointing
689*da0073e9SAndroid Build Coastguard Worker
690*da0073e9SAndroid Build Coastguard Worker BlockState(Block* block);
691*da0073e9SAndroid Build Coastguard Worker };
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker struct SegmentState {
694*da0073e9SAndroid Build Coastguard Worker std::vector<BlockState> blocks;
695*da0073e9SAndroid Build Coastguard Worker bool is_small = false;
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker SegmentState(Block* head);
698*da0073e9SAndroid Build Coastguard Worker };
699*da0073e9SAndroid Build Coastguard Worker
700*da0073e9SAndroid Build Coastguard Worker struct PrivatePoolState : AllocatorState {
701*da0073e9SAndroid Build Coastguard Worker   // omitting use_count and cudaMalloc_count, as they remain the same
702*da0073e9SAndroid Build Coastguard Worker MempoolId_t owner_id = {0, 0};
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker std::vector<SegmentState> segments;
705*da0073e9SAndroid Build Coastguard Worker
706*da0073e9SAndroid Build Coastguard Worker PrivatePoolState(
707*da0073e9SAndroid Build Coastguard Worker MempoolId_t pool_id,
708*da0073e9SAndroid Build Coastguard Worker const std::vector<Block*>& private_pool_head_blocks);
709*da0073e9SAndroid Build Coastguard Worker };
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker struct RestoreResult {
712*da0073e9SAndroid Build Coastguard Worker std::vector<void*> allocations_freed;
713*da0073e9SAndroid Build Coastguard Worker std::vector<Block*> allocations_created;
714*da0073e9SAndroid Build Coastguard Worker };
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker static bool BlockComparatorSize(const Block* a, const Block* b) {
717*da0073e9SAndroid Build Coastguard Worker if (a->stream != b->stream) {
718*da0073e9SAndroid Build Coastguard Worker return (uintptr_t)a->stream < (uintptr_t)b->stream;
719*da0073e9SAndroid Build Coastguard Worker }
720*da0073e9SAndroid Build Coastguard Worker if (a->size != b->size) {
721*da0073e9SAndroid Build Coastguard Worker return a->size < b->size;
722*da0073e9SAndroid Build Coastguard Worker }
723*da0073e9SAndroid Build Coastguard Worker return (uintptr_t)a->ptr < (uintptr_t)b->ptr;
724*da0073e9SAndroid Build Coastguard Worker }
725*da0073e9SAndroid Build Coastguard Worker static bool BlockComparatorAddress(const Block* a, const Block* b) {
726*da0073e9SAndroid Build Coastguard Worker if (a->stream != b->stream) {
727*da0073e9SAndroid Build Coastguard Worker return (uintptr_t)a->stream < (uintptr_t)b->stream;
728*da0073e9SAndroid Build Coastguard Worker }
729*da0073e9SAndroid Build Coastguard Worker return (uintptr_t)a->ptr < (uintptr_t)b->ptr;
730*da0073e9SAndroid Build Coastguard Worker }
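
// Illustrative example (the blocks below are hypothetical and not part of the
// allocator): because the stream is compared first, blocks belonging to
// different streams never interleave in the sorted BlockPool, so a lookup
// only ever walks candidates from the requesting stream.
//
//   Block a{/*stream=*/s1, /*size=*/1 << 20, /*ptr=*/(void*)0x7000};
//   Block b{/*stream=*/s1, /*size=*/2 << 20, /*ptr=*/(void*)0x5000};
//   // BlockComparatorSize:    a < b (same stream, smaller size wins)
//   // BlockComparatorAddress: b < a (same stream, lower address wins)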
731*da0073e9SAndroid Build Coastguard Worker
732*da0073e9SAndroid Build Coastguard Worker struct AllocParams {
733*da0073e9SAndroid Build Coastguard Worker   AllocParams(
734*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
735*da0073e9SAndroid Build Coastguard Worker size_t size,
736*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream,
737*da0073e9SAndroid Build Coastguard Worker BlockPool* pool,
738*da0073e9SAndroid Build Coastguard Worker size_t alloc_size,
739*da0073e9SAndroid Build Coastguard Worker DeviceStats& stats)
740*da0073e9SAndroid Build Coastguard Worker : search_key(device, stream, size), pool(pool), alloc_size(alloc_size) {}
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker   c10::DeviceIndex device() const {
743*da0073e9SAndroid Build Coastguard Worker return search_key.device;
744*da0073e9SAndroid Build Coastguard Worker }
745*da0073e9SAndroid Build Coastguard Worker   cudaStream_t stream() const {
746*da0073e9SAndroid Build Coastguard Worker return search_key.stream;
747*da0073e9SAndroid Build Coastguard Worker }
748*da0073e9SAndroid Build Coastguard Worker   size_t size() const {
749*da0073e9SAndroid Build Coastguard Worker return search_key.size;
750*da0073e9SAndroid Build Coastguard Worker }
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker Block search_key;
753*da0073e9SAndroid Build Coastguard Worker BlockPool* pool;
754*da0073e9SAndroid Build Coastguard Worker size_t alloc_size;
755*da0073e9SAndroid Build Coastguard Worker Block* block{nullptr};
756*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = {false};
757*da0073e9SAndroid Build Coastguard Worker cudaError_t err{cudaSuccess};
758*da0073e9SAndroid Build Coastguard Worker };
759*da0073e9SAndroid Build Coastguard Worker
760*da0073e9SAndroid Build Coastguard Worker // Note: cudaEventCreate when concurrently invoked from multiple threads can be
761*da0073e9SAndroid Build Coastguard Worker // very expensive (at least on certain device/driver combinations). Thus, we a)
762*da0073e9SAndroid Build Coastguard Worker // serialize event creation at a per-device level, and b) pool the events to
763*da0073e9SAndroid Build Coastguard Worker // avoid constantly calling cudaEventCreate/cudaEventDestroy. This results in
764*da0073e9SAndroid Build Coastguard Worker // significant improvements in multithreaded workloads with high allocation
765*da0073e9SAndroid Build Coastguard Worker // rates.
766*da0073e9SAndroid Build Coastguard Worker class EventPool {
767*da0073e9SAndroid Build Coastguard Worker public:
768*da0073e9SAndroid Build Coastguard Worker using Event = std::unique_ptr<cudaEvent_t, std::function<void(cudaEvent_t*)>>;
769*da0073e9SAndroid Build Coastguard Worker // TODO: Explicit device count
770*da0073e9SAndroid Build Coastguard Worker   EventPool() : pools_(at::cuda::device_count()) {}
771*da0073e9SAndroid Build Coastguard Worker
772*da0073e9SAndroid Build Coastguard Worker   Event get(c10::DeviceIndex device) {
773*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(0 <= device);
774*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(device < static_cast<int>(pools_.size()));
775*da0073e9SAndroid Build Coastguard Worker auto& pool = pools_[device];
776*da0073e9SAndroid Build Coastguard Worker auto destructor = [&pool](cudaEvent_t* event) {
777*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> g(pool.mutex_);
778*da0073e9SAndroid Build Coastguard Worker pool.event_pool_.push_back(std::unique_ptr<cudaEvent_t>(event));
779*da0073e9SAndroid Build Coastguard Worker };
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker // Try to acquire an event from the per-device pool.
782*da0073e9SAndroid Build Coastguard Worker {
783*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> g(pool.mutex_);
784*da0073e9SAndroid Build Coastguard Worker if (!pool.event_pool_.empty()) {
785*da0073e9SAndroid Build Coastguard Worker auto* event = pool.event_pool_.back().release();
786*da0073e9SAndroid Build Coastguard Worker pool.event_pool_.pop_back();
787*da0073e9SAndroid Build Coastguard Worker return Event(event, destructor);
788*da0073e9SAndroid Build Coastguard Worker }
789*da0073e9SAndroid Build Coastguard Worker }
790*da0073e9SAndroid Build Coastguard Worker // otherwise, allocate a new event that will be returned to the pool on
791*da0073e9SAndroid Build Coastguard Worker // destruction.
792*da0073e9SAndroid Build Coastguard Worker auto new_ptr = std::make_unique<cudaEvent_t>();
793*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(
794*da0073e9SAndroid Build Coastguard Worker cudaEventCreateWithFlags(new_ptr.get(), cudaEventDisableTiming));
795*da0073e9SAndroid Build Coastguard Worker
796*da0073e9SAndroid Build Coastguard Worker return Event(new_ptr.release(), destructor);
797*da0073e9SAndroid Build Coastguard Worker }
798*da0073e9SAndroid Build Coastguard Worker
799*da0073e9SAndroid Build Coastguard Worker   void empty_cache() {
800*da0073e9SAndroid Build Coastguard Worker for (auto& pool : pools_) {
801*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> g(pool.mutex_);
802*da0073e9SAndroid Build Coastguard Worker pool.event_pool_.clear();
803*da0073e9SAndroid Build Coastguard Worker }
804*da0073e9SAndroid Build Coastguard Worker }
805*da0073e9SAndroid Build Coastguard Worker
806*da0073e9SAndroid Build Coastguard Worker private:
807*da0073e9SAndroid Build Coastguard Worker struct PerDevicePool {
808*da0073e9SAndroid Build Coastguard Worker alignas(64) std::mutex mutex_;
809*da0073e9SAndroid Build Coastguard Worker std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
810*da0073e9SAndroid Build Coastguard Worker };
811*da0073e9SAndroid Build Coastguard Worker std::vector<PerDevicePool> pools_;
812*da0073e9SAndroid Build Coastguard Worker };
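
// A minimal usage sketch (illustrative only; `dev` and `stream` are
// hypothetical values). The Event's custom deleter returns the cudaEvent_t to
// the per-device free list instead of destroying it, so repeated get() calls
// avoid cudaEventCreate in steady state.
//
//   EventPool event_pool;
//   EventPool::Event ev = event_pool.get(dev);       // reuse or create
//   C10_CUDA_CHECK(cudaEventRecord(*ev, stream));    // use like any event
//   // ... query/synchronize as usual ...
//   // when `ev` goes out of scope, the event is recycled, not destroyed
//   event_pool.empty_cache();                        // drop all cached events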
813*da0073e9SAndroid Build Coastguard Worker
814*da0073e9SAndroid Build Coastguard Worker // CUDA graphs helper
815*da0073e9SAndroid Build Coastguard Worker struct PrivatePool {
816*da0073e9SAndroid Build Coastguard Worker   PrivatePool()
817*da0073e9SAndroid Build Coastguard Worker : large_blocks(/*small=*/false, this),
818*da0073e9SAndroid Build Coastguard Worker small_blocks(/*small=*/true, this) {}
819*da0073e9SAndroid Build Coastguard Worker PrivatePool(const PrivatePool&) = delete;
820*da0073e9SAndroid Build Coastguard Worker PrivatePool(PrivatePool&&) = delete;
821*da0073e9SAndroid Build Coastguard Worker PrivatePool& operator=(const PrivatePool&) = delete;
822*da0073e9SAndroid Build Coastguard Worker // Number of live graphs using this pool
823*da0073e9SAndroid Build Coastguard Worker int use_count{1};
824*da0073e9SAndroid Build Coastguard Worker // Number of unfreed cudaMallocs made for this pool. When use_count and
825*da0073e9SAndroid Build Coastguard Worker // cudaMalloc_count drop to zero, we can delete this PrivatePool from
826*da0073e9SAndroid Build Coastguard Worker // graph_pools.
827*da0073e9SAndroid Build Coastguard Worker int cudaMalloc_count{0};
828*da0073e9SAndroid Build Coastguard Worker   // Instead of maintaining private BlockPools here, I could stuff all blocks
829*da0073e9SAndroid Build Coastguard Worker   // (private or not) into the top-level large_blocks and small_blocks, and
830*da0073e9SAndroid Build Coastguard Worker   // distinguish private blocks by adding a "pool id" check above the stream
831*da0073e9SAndroid Build Coastguard Worker   // check in BlockComparator. BlockComparator is performance-critical, though,
832*da0073e9SAndroid Build Coastguard Worker   // so I'd rather not add more logic to it.
833*da0073e9SAndroid Build Coastguard Worker BlockPool large_blocks;
834*da0073e9SAndroid Build Coastguard Worker BlockPool small_blocks;
835*da0073e9SAndroid Build Coastguard Worker };
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker BlockState::BlockState(Block* block)
838*da0073e9SAndroid Build Coastguard Worker : device(block->device),
839*da0073e9SAndroid Build Coastguard Worker stream(block->stream),
840*da0073e9SAndroid Build Coastguard Worker stream_uses(block->stream_uses),
841*da0073e9SAndroid Build Coastguard Worker size(block->size),
842*da0073e9SAndroid Build Coastguard Worker ptr(block->ptr),
843*da0073e9SAndroid Build Coastguard Worker allocated(block->allocated),
844*da0073e9SAndroid Build Coastguard Worker gc_count_base(block->gc_count_base) {
845*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
846*da0073e9SAndroid Build Coastguard Worker block->event_count == 0,
847*da0073e9SAndroid Build Coastguard Worker "Events should have synchronized when checkpointing block");
848*da0073e9SAndroid Build Coastguard Worker };
849*da0073e9SAndroid Build Coastguard Worker
850*da0073e9SAndroid Build Coastguard Worker SegmentState::SegmentState(Block* head) {
851*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(head->prev == nullptr && head->pool != nullptr);
852*da0073e9SAndroid Build Coastguard Worker is_small = head->pool->is_small;
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker for (Block* curr = head; curr != nullptr; curr = curr->next) {
855*da0073e9SAndroid Build Coastguard Worker blocks.emplace_back(curr);
856*da0073e9SAndroid Build Coastguard Worker }
857*da0073e9SAndroid Build Coastguard Worker }
858*da0073e9SAndroid Build Coastguard Worker
859*da0073e9SAndroid Build Coastguard Worker PrivatePoolState::PrivatePoolState(
860*da0073e9SAndroid Build Coastguard Worker MempoolId_t pool_id,
861*da0073e9SAndroid Build Coastguard Worker const std::vector<Block*>& private_pool_head_blocks)
862*da0073e9SAndroid Build Coastguard Worker : owner_id(std::move(pool_id)) {
863*da0073e9SAndroid Build Coastguard Worker for (Block* head : private_pool_head_blocks) {
864*da0073e9SAndroid Build Coastguard Worker segments.emplace_back(head);
865*da0073e9SAndroid Build Coastguard Worker }
866*da0073e9SAndroid Build Coastguard Worker }
867*da0073e9SAndroid Build Coastguard Worker
868*da0073e9SAndroid Build Coastguard Worker struct MempoolIdHash {
869*da0073e9SAndroid Build Coastguard Worker   std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
870*da0073e9SAndroid Build Coastguard Worker return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
871*da0073e9SAndroid Build Coastguard Worker }
872*da0073e9SAndroid Build Coastguard Worker };
873*da0073e9SAndroid Build Coastguard Worker
874*da0073e9SAndroid Build Coastguard Worker cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) {
875*da0073e9SAndroid Build Coastguard Worker if (at::cuda::currentStreamCaptureStatusMayInitCtx() ==
876*da0073e9SAndroid Build Coastguard Worker at::cuda::CaptureStatus::None) {
877*da0073e9SAndroid Build Coastguard Worker return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size));
878*da0073e9SAndroid Build Coastguard Worker } else {
879*da0073e9SAndroid Build Coastguard Worker // It's ok to capture cudaMallocs, as long as we never cudaFree those
880*da0073e9SAndroid Build Coastguard Worker // addresses before replay.
881*da0073e9SAndroid Build Coastguard Worker // Capturing cudaMalloc behaves nicely: it gives the graph new VA,
882*da0073e9SAndroid Build Coastguard Worker // but is ignored (won't leakily allocate new memory) in replays.
883*da0073e9SAndroid Build Coastguard Worker at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed};
884*da0073e9SAndroid Build Coastguard Worker return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size));
885*da0073e9SAndroid Build Coastguard Worker }
886*da0073e9SAndroid Build Coastguard Worker }
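
// Sketch of the two call paths above (illustrative; `bytes` is a hypothetical
// request size):
//
//   void* p = nullptr;
//   cudaError_t err = cudaMallocMaybeCapturing(&p, bytes);
//   // - no capture underway: behaves exactly like cudaMalloc(&p, bytes)
//   // - capture underway:    the relaxed capture-mode guard allows cudaMalloc
//   //   to be captured; the graph gets fresh virtual addresses, and the
//   //   allocation is not repeated on replay
//   if (err != cudaSuccess) { /* caller decides how to recover */ }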
887*da0073e9SAndroid Build Coastguard Worker
888*da0073e9SAndroid Build Coastguard Worker template <class T>
889*da0073e9SAndroid Build Coastguard Worker class RingBuffer {
890*da0073e9SAndroid Build Coastguard Worker public:
891*da0073e9SAndroid Build Coastguard Worker   RingBuffer() {
892*da0073e9SAndroid Build Coastguard Worker     // alloc_trace is a pointer because we need to intentionally
893*da0073e9SAndroid Build Coastguard Worker     // leak it on deallocation: it can hold references to Python
894*da0073e9SAndroid Build Coastguard Worker     // state that will already be destroyed when we are in exit handlers.
895*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
896*da0073e9SAndroid Build Coastguard Worker alloc_trace = new std::vector<T>();
897*da0073e9SAndroid Build Coastguard Worker }
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Worker   void setMaxEntries(size_t size) {
900*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lk(alloc_trace_lock);
901*da0073e9SAndroid Build Coastguard Worker alloc_trace_max_entries_ = std::max(size_t(1), size);
902*da0073e9SAndroid Build Coastguard Worker }
903*da0073e9SAndroid Build Coastguard Worker
904*da0073e9SAndroid Build Coastguard Worker   void insertEntries(const T& entry) {
905*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lk(alloc_trace_lock);
906*da0073e9SAndroid Build Coastguard Worker if (alloc_trace->size() < alloc_trace_max_entries_) {
907*da0073e9SAndroid Build Coastguard Worker alloc_trace->emplace_back(entry);
908*da0073e9SAndroid Build Coastguard Worker } else {
909*da0073e9SAndroid Build Coastguard Worker (*alloc_trace)[alloc_trace_next++] = entry;
910*da0073e9SAndroid Build Coastguard Worker if (alloc_trace_next == alloc_trace_max_entries_) {
911*da0073e9SAndroid Build Coastguard Worker alloc_trace_next = 0;
912*da0073e9SAndroid Build Coastguard Worker }
913*da0073e9SAndroid Build Coastguard Worker }
914*da0073e9SAndroid Build Coastguard Worker }
915*da0073e9SAndroid Build Coastguard Worker
916*da0073e9SAndroid Build Coastguard Worker   void getEntries(std::vector<T>& result) {
917*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lk(alloc_trace_lock);
918*da0073e9SAndroid Build Coastguard Worker result.reserve(alloc_trace->size());
919*da0073e9SAndroid Build Coastguard Worker result.insert(
920*da0073e9SAndroid Build Coastguard Worker result.end(),
921*da0073e9SAndroid Build Coastguard Worker alloc_trace->begin() +
922*da0073e9SAndroid Build Coastguard Worker static_cast<typename std::vector<T>::difference_type>(
923*da0073e9SAndroid Build Coastguard Worker alloc_trace_next),
924*da0073e9SAndroid Build Coastguard Worker alloc_trace->end());
925*da0073e9SAndroid Build Coastguard Worker result.insert(
926*da0073e9SAndroid Build Coastguard Worker result.end(),
927*da0073e9SAndroid Build Coastguard Worker alloc_trace->begin(),
928*da0073e9SAndroid Build Coastguard Worker alloc_trace->begin() +
929*da0073e9SAndroid Build Coastguard Worker static_cast<typename std::vector<T>::difference_type>(
930*da0073e9SAndroid Build Coastguard Worker alloc_trace_next));
931*da0073e9SAndroid Build Coastguard Worker }
932*da0073e9SAndroid Build Coastguard Worker
933*da0073e9SAndroid Build Coastguard Worker   void clear() {
934*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lk(alloc_trace_lock);
935*da0073e9SAndroid Build Coastguard Worker alloc_trace_next = 0;
936*da0073e9SAndroid Build Coastguard Worker alloc_trace->clear();
937*da0073e9SAndroid Build Coastguard Worker }
938*da0073e9SAndroid Build Coastguard Worker
939*da0073e9SAndroid Build Coastguard Worker private:
940*da0073e9SAndroid Build Coastguard Worker size_t alloc_trace_max_entries_ = 1;
941*da0073e9SAndroid Build Coastguard Worker
942*da0073e9SAndroid Build Coastguard Worker   // Both alloc_trace and alloc_trace_next need to be used
943*da0073e9SAndroid Build Coastguard Worker   // under alloc_trace_lock.
944*da0073e9SAndroid Build Coastguard Worker std::mutex alloc_trace_lock;
945*da0073e9SAndroid Build Coastguard Worker size_t alloc_trace_next = 0;
946*da0073e9SAndroid Build Coastguard Worker std::vector<T>*
947*da0073e9SAndroid Build Coastguard Worker       alloc_trace; // pointer because we need to intentionally leak this on
948*da0073e9SAndroid Build Coastguard Worker                    // deallocation: it can hold references to Python state
949*da0073e9SAndroid Build Coastguard Worker                    // that will already be destroyed when we are in exit handlers
950*da0073e9SAndroid Build Coastguard Worker };
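
// Example (illustrative only): once full, the buffer overwrites the oldest
// entry, and getEntries() reassembles the entries oldest-first by reading
// from alloc_trace_next to the end, then from the beginning up to
// alloc_trace_next.
//
//   RingBuffer<int> rb;
//   rb.setMaxEntries(3);
//   for (int i = 1; i <= 5; ++i) {
//     rb.insertEntries(i);      // storage after the loop: {4, 5, 3}
//   }
//   std::vector<int> out;
//   rb.getEntries(out);         // out == {3, 4, 5}, oldest first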
951*da0073e9SAndroid Build Coastguard Worker
952*da0073e9SAndroid Build Coastguard Worker } // anonymous namespace
953*da0073e9SAndroid Build Coastguard Worker } // namespace Native
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker static std::string reportProcessMemoryInfo(c10::DeviceIndex device) {
956*da0073e9SAndroid Build Coastguard Worker #ifdef PYTORCH_C10_DRIVER_API_SUPPORTED
957*da0073e9SAndroid Build Coastguard Worker void* nvml_handle = DriverAPI::get_nvml_handle();
958*da0073e9SAndroid Build Coastguard Worker if (!nvml_handle) {
959*da0073e9SAndroid Build Coastguard Worker return "";
960*da0073e9SAndroid Build Coastguard Worker }
961*da0073e9SAndroid Build Coastguard Worker static c10::once_flag nvml_init;
962*da0073e9SAndroid Build Coastguard Worker c10::call_once(nvml_init, [] {
963*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(NVML_SUCCESS == DriverAPI::get()->nvmlInit_v2_());
964*da0073e9SAndroid Build Coastguard Worker });
965*da0073e9SAndroid Build Coastguard Worker
966*da0073e9SAndroid Build Coastguard Worker cudaDeviceProp prop{};
967*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
968*da0073e9SAndroid Build Coastguard Worker
969*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(*-c-arrays)
970*da0073e9SAndroid Build Coastguard Worker char pci_id[80];
971*da0073e9SAndroid Build Coastguard Worker snprintf(
972*da0073e9SAndroid Build Coastguard Worker pci_id,
973*da0073e9SAndroid Build Coastguard Worker sizeof(pci_id),
974*da0073e9SAndroid Build Coastguard Worker NVML_DEVICE_PCI_BUS_ID_FMT,
975*da0073e9SAndroid Build Coastguard Worker prop.pciDomainID,
976*da0073e9SAndroid Build Coastguard Worker prop.pciBusID,
977*da0073e9SAndroid Build Coastguard Worker prop.pciDeviceID);
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker nvmlDevice_t nvml_device = nullptr;
980*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
981*da0073e9SAndroid Build Coastguard Worker NVML_SUCCESS ==
982*da0073e9SAndroid Build Coastguard Worker DriverAPI::get()->nvmlDeviceGetHandleByPciBusId_v2_(
983*da0073e9SAndroid Build Coastguard Worker pci_id, &nvml_device));
984*da0073e9SAndroid Build Coastguard Worker
985*da0073e9SAndroid Build Coastguard Worker std::vector<nvmlProcessInfo_v1_t> procs(8);
986*da0073e9SAndroid Build Coastguard Worker unsigned int size = procs.size();
987*da0073e9SAndroid Build Coastguard Worker nvmlReturn_t r{};
988*da0073e9SAndroid Build Coastguard Worker while ((r = DriverAPI::get()->nvmlDeviceGetComputeRunningProcesses_(
989*da0073e9SAndroid Build Coastguard Worker nvml_device, &size, procs.data())) ==
990*da0073e9SAndroid Build Coastguard Worker NVML_ERROR_INSUFFICIENT_SIZE) {
991*da0073e9SAndroid Build Coastguard Worker procs.resize(size);
992*da0073e9SAndroid Build Coastguard Worker }
993*da0073e9SAndroid Build Coastguard Worker unsigned int self_pid = getpid();
994*da0073e9SAndroid Build Coastguard Worker std::stringstream ss;
995*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(NVML_SUCCESS == r);
996*da0073e9SAndroid Build Coastguard Worker ss << "";
997*da0073e9SAndroid Build Coastguard Worker for (auto i : c10::irange(size)) {
998*da0073e9SAndroid Build Coastguard Worker auto& proc = procs[i];
999*da0073e9SAndroid Build Coastguard Worker if (self_pid == proc.pid) {
1000*da0073e9SAndroid Build Coastguard Worker ss << "Including non-PyTorch memory, this process";
1001*da0073e9SAndroid Build Coastguard Worker } else {
1002*da0073e9SAndroid Build Coastguard Worker ss << "Process " << proc.pid;
1003*da0073e9SAndroid Build Coastguard Worker }
1004*da0073e9SAndroid Build Coastguard Worker ss << " has " << format_size(proc.usedGpuMemory) << " memory in use. ";
1005*da0073e9SAndroid Build Coastguard Worker }
1006*da0073e9SAndroid Build Coastguard Worker return ss.str();
1007*da0073e9SAndroid Build Coastguard Worker #else
1008*da0073e9SAndroid Build Coastguard Worker return "";
1009*da0073e9SAndroid Build Coastguard Worker #endif
1010*da0073e9SAndroid Build Coastguard Worker }
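
// Sketch of the string built above (the PIDs and sizes are made up): on
// success it is a concatenation of per-process summaries, roughly of the form
//
//   "Process 4021 has 1.50 GiB memory in use. Including non-PyTorch memory,
//    this process has 6.20 GiB memory in use. "
//
// and the empty string when NVML is unavailable or driver API support is
// compiled out.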
1011*da0073e9SAndroid Build Coastguard Worker
1012*da0073e9SAndroid Build Coastguard Worker namespace Native {
1013*da0073e9SAndroid Build Coastguard Worker
1014*da0073e9SAndroid Build Coastguard Worker class DeviceCachingAllocator {
1015*da0073e9SAndroid Build Coastguard Worker private:
1016*da0073e9SAndroid Build Coastguard Worker // lock around all operations
1017*da0073e9SAndroid Build Coastguard Worker mutable std::recursive_mutex mutex;
1018*da0073e9SAndroid Build Coastguard Worker
1019*da0073e9SAndroid Build Coastguard Worker // device statistics
1020*da0073e9SAndroid Build Coastguard Worker DeviceStats stats;
1021*da0073e9SAndroid Build Coastguard Worker
1022*da0073e9SAndroid Build Coastguard Worker // unallocated cached blocks larger than 1 MB
1023*da0073e9SAndroid Build Coastguard Worker BlockPool large_blocks;
1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard Worker // unallocated cached blocks 1 MB or smaller
1026*da0073e9SAndroid Build Coastguard Worker BlockPool small_blocks;
1027*da0073e9SAndroid Build Coastguard Worker
1028*da0073e9SAndroid Build Coastguard Worker // allocated or in use by a stream. Holds all active allocations,
1029*da0073e9SAndroid Build Coastguard Worker // whether they came from graph_pools or one of the BlockPools above.
1030*da0073e9SAndroid Build Coastguard Worker ska::flat_hash_set<Block*> active_blocks;
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard Worker // captures_underway tracks if we are diverting some
1033*da0073e9SAndroid Build Coastguard Worker // allocations to a specific pool.
1034*da0073e9SAndroid Build Coastguard Worker // Most of the time it's empty, in which case malloc can avoid calling
1035*da0073e9SAndroid Build Coastguard Worker // cudaStreamGetCaptureInfo in the hot path.
1036*da0073e9SAndroid Build Coastguard Worker std::vector<std::pair<MempoolId_t, std::function<bool(cudaStream_t)>>>
1037*da0073e9SAndroid Build Coastguard Worker captures_underway;
1038*da0073e9SAndroid Build Coastguard Worker
1039*da0073e9SAndroid Build Coastguard Worker // See free() for this thing's purpose
1040*da0073e9SAndroid Build Coastguard Worker std::vector<Block*> needs_events_deferred_until_no_capture;
1041*da0073e9SAndroid Build Coastguard Worker // outstanding cuda events
1042*da0073e9SAndroid Build Coastguard Worker ska::flat_hash_map<
1043*da0073e9SAndroid Build Coastguard Worker cuda::CUDAStream,
1044*da0073e9SAndroid Build Coastguard Worker std::deque<std::pair<EventPool::Event, Block*>>>
1045*da0073e9SAndroid Build Coastguard Worker cuda_events;
1046*da0073e9SAndroid Build Coastguard Worker
1047*da0073e9SAndroid Build Coastguard Worker // record used memory.
1048*da0073e9SAndroid Build Coastguard Worker size_t total_allocated_memory = 0;
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard Worker size_t allowed_memory_maximum = 0;
1051*da0073e9SAndroid Build Coastguard Worker
1052*da0073e9SAndroid Build Coastguard Worker // all live expandable segments
1053*da0073e9SAndroid Build Coastguard Worker std::vector<ExpandableSegment*> expandable_segments_;
1054*da0073e9SAndroid Build Coastguard Worker std::vector<c10::DeviceIndex> devices_with_peer_access_;
1055*da0073e9SAndroid Build Coastguard Worker
1056*da0073e9SAndroid Build Coastguard Worker bool set_fraction = false;
1057*da0073e9SAndroid Build Coastguard Worker
1058*da0073e9SAndroid Build Coastguard Worker bool record_history = false;
1059*da0073e9SAndroid Build Coastguard Worker
1060*da0073e9SAndroid Build Coastguard Worker std::atomic<CreateContextFn> context_recorder_;
1061*da0073e9SAndroid Build Coastguard Worker RecordContext record_context_ = RecordContext::NEVER;
1062*da0073e9SAndroid Build Coastguard Worker
1063*da0073e9SAndroid Build Coastguard Worker // Ring buffer for memory snapshot TraceEntry's
1064*da0073e9SAndroid Build Coastguard Worker RingBuffer<TraceEntry> alloc_buffer;
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker // Members specific to CUDA graphs
1067*da0073e9SAndroid Build Coastguard Worker
1068*da0073e9SAndroid Build Coastguard Worker // Private pools for CUDA graphs
1069*da0073e9SAndroid Build Coastguard Worker ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
1070*da0073e9SAndroid Build Coastguard Worker graph_pools;
1071*da0073e9SAndroid Build Coastguard Worker // Pools no longer referenced by any graph. Their BlockPools are eligible for
1072*da0073e9SAndroid Build Coastguard Worker // free_blocks. Can't be a vector or deque because we might erase entries in
1073*da0073e9SAndroid Build Coastguard Worker   // any order. Could be a std::list, but we don't care much; access and
1074*da0073e9SAndroid Build Coastguard Worker   // insert/erase are rare.
1075*da0073e9SAndroid Build Coastguard Worker ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash>
1076*da0073e9SAndroid Build Coastguard Worker graph_pools_freeable;
1077*da0073e9SAndroid Build Coastguard Worker
1078*da0073e9SAndroid Build Coastguard Worker // XXX - maybe we should generalize and have multiple events
1079*da0073e9SAndroid Build Coastguard Worker std::vector<OutOfMemoryObserver> oom_observers_;
1080*da0073e9SAndroid Build Coastguard Worker
1081*da0073e9SAndroid Build Coastguard Worker std::vector<AllocatorTraceTracker> trace_trackers_;
1082*da0073e9SAndroid Build Coastguard Worker
1083*da0073e9SAndroid Build Coastguard Worker // mapping from block to a stream_set, containing streams on which the block
1084*da0073e9SAndroid Build Coastguard Worker // was used while cudagraph capturing
1085*da0073e9SAndroid Build Coastguard Worker std::unordered_map<Block*, stream_set> block_to_cudagraph_stream_uses;
1086*da0073e9SAndroid Build Coastguard Worker
1087*da0073e9SAndroid Build Coastguard Worker public:
1088*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1089*da0073e9SAndroid Build Coastguard Worker   DeviceCachingAllocator()
1090*da0073e9SAndroid Build Coastguard Worker : large_blocks(/*small=*/false), small_blocks(/*small=*/true) {
1091*da0073e9SAndroid Build Coastguard Worker stats.max_split_size =
1092*da0073e9SAndroid Build Coastguard Worker static_cast<int64_t>(CUDAAllocatorConfig::max_split_size());
1093*da0073e9SAndroid Build Coastguard Worker context_recorder_.store(nullptr);
1094*da0073e9SAndroid Build Coastguard Worker }
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker   void recordHistory(
1097*da0073e9SAndroid Build Coastguard Worker bool enabled,
1098*da0073e9SAndroid Build Coastguard Worker CreateContextFn context_recorder,
1099*da0073e9SAndroid Build Coastguard Worker size_t alloc_buffer_max_entries,
1100*da0073e9SAndroid Build Coastguard Worker RecordContext when) {
1101*da0073e9SAndroid Build Coastguard Worker std::unique_lock<std::recursive_mutex> lock(mutex);
1102*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(when == RecordContext::NEVER || context_recorder);
1103*da0073e9SAndroid Build Coastguard Worker record_history = enabled;
1104*da0073e9SAndroid Build Coastguard Worker context_recorder_.store(record_history ? context_recorder : nullptr);
1105*da0073e9SAndroid Build Coastguard Worker alloc_buffer.setMaxEntries(alloc_buffer_max_entries);
1106*da0073e9SAndroid Build Coastguard Worker record_context_ = enabled ? when : RecordContext::NEVER;
1107*da0073e9SAndroid Build Coastguard Worker if (!enabled) {
1108*da0073e9SAndroid Build Coastguard Worker alloc_buffer.clear();
1109*da0073e9SAndroid Build Coastguard Worker }
1110*da0073e9SAndroid Build Coastguard Worker }
1111*da0073e9SAndroid Build Coastguard Worker
1112*da0073e9SAndroid Build Coastguard Worker   bool isHistoryEnabled() {
1113*da0073e9SAndroid Build Coastguard Worker return record_history;
1114*da0073e9SAndroid Build Coastguard Worker }
1115*da0073e9SAndroid Build Coastguard Worker
1116*da0073e9SAndroid Build Coastguard Worker   bool checkPoolLiveAllocations(
1117*da0073e9SAndroid Build Coastguard Worker MempoolId_t mempool_id,
1118*da0073e9SAndroid Build Coastguard Worker const std::unordered_set<void*>& expected_live_allocations) {
1119*da0073e9SAndroid Build Coastguard Worker std::unique_lock<std::recursive_mutex> lock(mutex);
1120*da0073e9SAndroid Build Coastguard Worker
1121*da0073e9SAndroid Build Coastguard Worker PrivatePool* pool = nullptr;
1122*da0073e9SAndroid Build Coastguard Worker auto pool_it = graph_pools.find(mempool_id);
1123*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(pool_it != graph_pools.end(), "Could not find pool of id");
1124*da0073e9SAndroid Build Coastguard Worker pool = pool_it->second.get();
1125*da0073e9SAndroid Build Coastguard Worker
1126*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(pool != nullptr);
1127*da0073e9SAndroid Build Coastguard Worker
1128*da0073e9SAndroid Build Coastguard Worker size_t allocated_pool_blocks = 0;
1129*da0073e9SAndroid Build Coastguard Worker
1130*da0073e9SAndroid Build Coastguard Worker for (Block* b : active_blocks) {
1131*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(b != nullptr);
1132*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(b->pool != nullptr);
1133*da0073e9SAndroid Build Coastguard Worker if (b->allocated && b->pool->owner_PrivatePool == pool) {
1134*da0073e9SAndroid Build Coastguard Worker if (!expected_live_allocations.count(b->ptr)) {
1135*da0073e9SAndroid Build Coastguard Worker return false;
1136*da0073e9SAndroid Build Coastguard Worker }
1137*da0073e9SAndroid Build Coastguard Worker
1138*da0073e9SAndroid Build Coastguard Worker allocated_pool_blocks += 1;
1139*da0073e9SAndroid Build Coastguard Worker }
1140*da0073e9SAndroid Build Coastguard Worker }
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker return allocated_pool_blocks == expected_live_allocations.size();
1143*da0073e9SAndroid Build Coastguard Worker }
1144*da0073e9SAndroid Build Coastguard Worker
1145*da0073e9SAndroid Build Coastguard Worker   void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
1146*da0073e9SAndroid Build Coastguard Worker oom_observers_.emplace_back(std::move(observer));
1147*da0073e9SAndroid Build Coastguard Worker }
1148*da0073e9SAndroid Build Coastguard Worker
1149*da0073e9SAndroid Build Coastguard Worker   void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
1150*da0073e9SAndroid Build Coastguard Worker std::unique_lock<std::recursive_mutex> lock(mutex);
1151*da0073e9SAndroid Build Coastguard Worker trace_trackers_.emplace_back(std::move(tracker));
1152*da0073e9SAndroid Build Coastguard Worker }
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker // Must be called outside of `mutex` or deadlocks are possible with Python
1155*da0073e9SAndroid Build Coastguard Worker   std::shared_ptr<GatheredContext> maybeGatherContext(RecordContext level) {
1156*da0073e9SAndroid Build Coastguard Worker if (record_context_ < level) {
1157*da0073e9SAndroid Build Coastguard Worker return nullptr;
1158*da0073e9SAndroid Build Coastguard Worker }
1159*da0073e9SAndroid Build Coastguard Worker return context_recorder_.load()();
1160*da0073e9SAndroid Build Coastguard Worker }
1161*da0073e9SAndroid Build Coastguard Worker
1162*da0073e9SAndroid Build Coastguard Worker // All public methods (except the above) acquire the allocator mutex.
1163*da0073e9SAndroid Build Coastguard Worker // Thus, do not call a public method from another public method.
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Worker   Block* malloc(
1166*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
1167*da0073e9SAndroid Build Coastguard Worker size_t orig_size,
1168*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream) {
1169*da0073e9SAndroid Build Coastguard Worker // done outside the lock because we don't know what locks the recorder needs
1170*da0073e9SAndroid Build Coastguard Worker // to have...
1171*da0073e9SAndroid Build Coastguard Worker auto context = maybeGatherContext(RecordContext::STATE);
1172*da0073e9SAndroid Build Coastguard Worker
1173*da0073e9SAndroid Build Coastguard Worker std::unique_lock<std::recursive_mutex> lock(mutex);
1174*da0073e9SAndroid Build Coastguard Worker
1175*da0073e9SAndroid Build Coastguard Worker if (C10_LIKELY(captures_underway.empty())) {
1176*da0073e9SAndroid Build Coastguard Worker // Processes end-of-life events for outstanding allocations used on
1177*da0073e9SAndroid Build Coastguard Worker // multiple streams (checks if their GPU-side uses are complete and
1178*da0073e9SAndroid Build Coastguard Worker // recycles their memory if so)
1179*da0073e9SAndroid Build Coastguard Worker //
1180*da0073e9SAndroid Build Coastguard Worker // Q. Why skip process_events if a capture might be underway?
1181*da0073e9SAndroid Build Coastguard Worker // A. process_events involves cudaEventQueries, illegal during CUDA graph
1182*da0073e9SAndroid Build Coastguard Worker // capture.
1183*da0073e9SAndroid Build Coastguard Worker // Dumb simple solution: defer reclaiming these allocations until after
1184*da0073e9SAndroid Build Coastguard Worker // capture. Cross-stream memory use is uncommon, so the deferral's
1185*da0073e9SAndroid Build Coastguard Worker // effect on memory use during capture should be small.
1186*da0073e9SAndroid Build Coastguard Worker process_events(context);
1187*da0073e9SAndroid Build Coastguard Worker }
1188*da0073e9SAndroid Build Coastguard Worker size_t size = round_size(orig_size);
1189*da0073e9SAndroid Build Coastguard Worker auto& pool = get_pool(size, stream);
1190*da0073e9SAndroid Build Coastguard Worker const size_t alloc_size = get_allocation_size(size);
1191*da0073e9SAndroid Build Coastguard Worker AllocParams params(device, size, stream, &pool, alloc_size, stats);
1192*da0073e9SAndroid Build Coastguard Worker params.stat_types = get_stat_types_for_pool(pool);
1193*da0073e9SAndroid Build Coastguard Worker
1194*da0073e9SAndroid Build Coastguard Worker // First, try to get a block from the existing pool.
1195*da0073e9SAndroid Build Coastguard Worker bool block_found =
1196*da0073e9SAndroid Build Coastguard Worker // Search pool
1197*da0073e9SAndroid Build Coastguard Worker get_free_block(params)
1198*da0073e9SAndroid Build Coastguard Worker // Trigger callbacks and retry search
1199*da0073e9SAndroid Build Coastguard Worker || (trigger_free_memory_callbacks(params) && get_free_block(params));
1200*da0073e9SAndroid Build Coastguard Worker
1201*da0073e9SAndroid Build Coastguard Worker // Can't reuse an existing block; try to get a new one.
1202*da0073e9SAndroid Build Coastguard Worker if (!block_found) {
1203*da0073e9SAndroid Build Coastguard Worker // Do garbage collection if the flag is set.
1204*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(
1205*da0073e9SAndroid Build Coastguard Worker set_fraction &&
1206*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
1207*da0073e9SAndroid Build Coastguard Worker garbage_collect_cached_blocks(context);
1208*da0073e9SAndroid Build Coastguard Worker }
1209*da0073e9SAndroid Build Coastguard Worker // Attempt allocate
1210*da0073e9SAndroid Build Coastguard Worker // WARNING: alloc_block may release the allocator lock when calling
1211*da0073e9SAndroid Build Coastguard Worker // cudaMalloc. So far this function has not modified allocator state, but
1212*da0073e9SAndroid Build Coastguard Worker // keep in mind that any observed allocator state may change across calls
1213*da0073e9SAndroid Build Coastguard Worker // to alloc_block since it may release the lock.
1214*da0073e9SAndroid Build Coastguard Worker block_found = alloc_block(params, false, context, lock)
1215*da0073e9SAndroid Build Coastguard Worker // Free enough available cached blocks to satisfy alloc and retry
1216*da0073e9SAndroid Build Coastguard Worker // alloc.
1217*da0073e9SAndroid Build Coastguard Worker || (release_available_cached_blocks(params, context) &&
1218*da0073e9SAndroid Build Coastguard Worker alloc_block(params, false, context, lock))
1219*da0073e9SAndroid Build Coastguard Worker // Free all non-split cached blocks and retry alloc.
1220*da0073e9SAndroid Build Coastguard Worker || (C10_LIKELY(captures_underway.empty()) &&
1221*da0073e9SAndroid Build Coastguard Worker release_cached_blocks(context) &&
1222*da0073e9SAndroid Build Coastguard Worker alloc_block(params, true, context, lock));
1223*da0073e9SAndroid Build Coastguard Worker }
1224*da0073e9SAndroid Build Coastguard Worker
1225*da0073e9SAndroid Build Coastguard Worker if (!block_found) {
1226*da0073e9SAndroid Build Coastguard Worker // For any error code other than cudaErrorMemoryAllocation,
1227*da0073e9SAndroid Build Coastguard Worker // alloc_block should have thrown an exception already.
1228*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation);
1229*da0073e9SAndroid Build Coastguard Worker
1230*da0073e9SAndroid Build Coastguard Worker size_t device_free = 0;
1231*da0073e9SAndroid Build Coastguard Worker size_t device_total = 0;
1232*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
1233*da0073e9SAndroid Build Coastguard Worker std::string allowed_info;
1234*da0073e9SAndroid Build Coastguard Worker
1235*da0073e9SAndroid Build Coastguard Worker if (set_fraction) {
1236*da0073e9SAndroid Build Coastguard Worker allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
1237*da0073e9SAndroid Build Coastguard Worker }
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker std::string proc_info = reportProcessMemoryInfo(device);
1240*da0073e9SAndroid Build Coastguard Worker
1241*da0073e9SAndroid Build Coastguard Worker record_trace(
1242*da0073e9SAndroid Build Coastguard Worker TraceEntry::OOM,
1243*da0073e9SAndroid Build Coastguard Worker device_free,
1244*da0073e9SAndroid Build Coastguard Worker params.size(),
1245*da0073e9SAndroid Build Coastguard Worker params.stream(),
1246*da0073e9SAndroid Build Coastguard Worker params.device(),
1247*da0073e9SAndroid Build Coastguard Worker std::move(context));
1248*da0073e9SAndroid Build Coastguard Worker stats.num_ooms += 1;
1249*da0073e9SAndroid Build Coastguard Worker
1250*da0073e9SAndroid Build Coastguard Worker c10::reportOutOfMemoryToProfiler(
1251*da0073e9SAndroid Build Coastguard Worker static_cast<int64_t>(size),
1252*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
1253*da0073e9SAndroid Build Coastguard Worker .current,
1254*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
1255*da0073e9SAndroid Build Coastguard Worker .current,
1256*da0073e9SAndroid Build Coastguard Worker c10::Device(c10::DeviceType::CUDA, device));
1257*da0073e9SAndroid Build Coastguard Worker
1258*da0073e9SAndroid Build Coastguard Worker auto allocated_bytes =
1259*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
1260*da0073e9SAndroid Build Coastguard Worker .current;
1261*da0073e9SAndroid Build Coastguard Worker auto reserved_bytes =
1262*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
1263*da0073e9SAndroid Build Coastguard Worker .current;
1264*da0073e9SAndroid Build Coastguard Worker auto observers_local = oom_observers_;
1265*da0073e9SAndroid Build Coastguard Worker
1266*da0073e9SAndroid Build Coastguard Worker size_t allocated_in_private_pools = 0;
1267*da0073e9SAndroid Build Coastguard Worker auto get_size_block = [](const BlockPool& pool) {
1268*da0073e9SAndroid Build Coastguard Worker size_t res = 0;
1269*da0073e9SAndroid Build Coastguard Worker for (const auto& block : pool.blocks) {
1270*da0073e9SAndroid Build Coastguard Worker res += block->size;
1271*da0073e9SAndroid Build Coastguard Worker }
1272*da0073e9SAndroid Build Coastguard Worker return res;
1273*da0073e9SAndroid Build Coastguard Worker };
1274*da0073e9SAndroid Build Coastguard Worker for (const auto& p : graph_pools) {
1275*da0073e9SAndroid Build Coastguard Worker allocated_in_private_pools += get_size_block(p.second->large_blocks);
1276*da0073e9SAndroid Build Coastguard Worker allocated_in_private_pools += get_size_block(p.second->small_blocks);
1277*da0073e9SAndroid Build Coastguard Worker }
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker std::string private_pool_msg;
1280*da0073e9SAndroid Build Coastguard Worker
1281*da0073e9SAndroid Build Coastguard Worker if (allocated_in_private_pools > 0) {
1282*da0073e9SAndroid Build Coastguard Worker private_pool_msg = "with " + format_size(allocated_in_private_pools) +
1283*da0073e9SAndroid Build Coastguard Worker " allocated in private pools (e.g., CUDA Graphs), ";
1284*da0073e9SAndroid Build Coastguard Worker }
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Worker       // Make sure we do not hold the device lock before calling our
1287*da0073e9SAndroid Build Coastguard Worker       // observers, which might need to hold the GIL.
1288*da0073e9SAndroid Build Coastguard Worker       // It is safe to release at this point because we will no longer
1289*da0073e9SAndroid Build Coastguard Worker       // be reading any allocator state.
1290*da0073e9SAndroid Build Coastguard Worker
1291*da0073e9SAndroid Build Coastguard Worker lock.unlock();
1292*da0073e9SAndroid Build Coastguard Worker
1293*da0073e9SAndroid Build Coastguard Worker for (const auto& obs : observers_local) {
1294*da0073e9SAndroid Build Coastguard Worker obs(device,
1295*da0073e9SAndroid Build Coastguard Worker alloc_size,
1296*da0073e9SAndroid Build Coastguard Worker set_fraction ? allowed_memory_maximum : device_total,
1297*da0073e9SAndroid Build Coastguard Worker device_free);
1298*da0073e9SAndroid Build Coastguard Worker }
1299*da0073e9SAndroid Build Coastguard Worker
1300*da0073e9SAndroid Build Coastguard Worker // "total capacity": total global memory on GPU
1301*da0073e9SAndroid Build Coastguard Worker       // "allowed": memory the allocator is allowed to use, as set by the fraction.
1302*da0073e9SAndroid Build Coastguard Worker // "already allocated": memory allocated by the program using the
1303*da0073e9SAndroid Build Coastguard Worker // caching allocator
1304*da0073e9SAndroid Build Coastguard Worker // "free": free memory as reported by the CUDA API
1305*da0073e9SAndroid Build Coastguard Worker // "cached": memory held by the allocator but not used by the program
1306*da0073e9SAndroid Build Coastguard Worker //
1307*da0073e9SAndroid Build Coastguard Worker // The "allocated" amount does not include memory allocated outside
1308*da0073e9SAndroid Build Coastguard Worker // of the caching allocator, such as memory allocated by other programs
1309*da0073e9SAndroid Build Coastguard Worker // or memory held by the driver.
1310*da0073e9SAndroid Build Coastguard Worker //
1311*da0073e9SAndroid Build Coastguard Worker // The sum of "allocated" + "free" + "cached" may be less than the
1312*da0073e9SAndroid Build Coastguard Worker // total capacity due to memory held by the driver and usage by other
1313*da0073e9SAndroid Build Coastguard Worker // programs.
1314*da0073e9SAndroid Build Coastguard Worker //
1315*da0073e9SAndroid Build Coastguard Worker // Note that at this point free_cached_blocks has already returned all
1316*da0073e9SAndroid Build Coastguard Worker // possible "cached" memory to the driver. The only remaining "cached"
1317*da0073e9SAndroid Build Coastguard Worker // memory is split from a larger block that is partially in-use.
1318*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK_WITH(
1319*da0073e9SAndroid Build Coastguard Worker OutOfMemoryError,
1320*da0073e9SAndroid Build Coastguard Worker false,
1321*da0073e9SAndroid Build Coastguard Worker "CUDA out of memory. Tried to allocate ",
1322*da0073e9SAndroid Build Coastguard Worker format_size(alloc_size),
1323*da0073e9SAndroid Build Coastguard Worker ". GPU ",
1324*da0073e9SAndroid Build Coastguard Worker static_cast<int>(device),
1325*da0073e9SAndroid Build Coastguard Worker " has a total capacity of ",
1326*da0073e9SAndroid Build Coastguard Worker format_size(device_total),
1327*da0073e9SAndroid Build Coastguard Worker " of which ",
1328*da0073e9SAndroid Build Coastguard Worker format_size(device_free),
1329*da0073e9SAndroid Build Coastguard Worker " is free. ",
1330*da0073e9SAndroid Build Coastguard Worker proc_info,
1331*da0073e9SAndroid Build Coastguard Worker allowed_info,
1332*da0073e9SAndroid Build Coastguard Worker "Of the allocated memory ",
1333*da0073e9SAndroid Build Coastguard Worker format_size(allocated_bytes + allocated_in_private_pools),
1334*da0073e9SAndroid Build Coastguard Worker " is allocated by PyTorch, ",
1335*da0073e9SAndroid Build Coastguard Worker private_pool_msg,
1336*da0073e9SAndroid Build Coastguard Worker "and ",
1337*da0073e9SAndroid Build Coastguard Worker format_size(
1338*da0073e9SAndroid Build Coastguard Worker reserved_bytes - allocated_bytes - allocated_in_private_pools),
1339*da0073e9SAndroid Build Coastguard Worker " is reserved by PyTorch but unallocated.",
1340*da0073e9SAndroid Build Coastguard Worker " If reserved but unallocated memory is large try setting",
1341*da0073e9SAndroid Build Coastguard Worker " PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid"
1342*da0073e9SAndroid Build Coastguard Worker " fragmentation. See documentation for Memory Management "
1343*da0073e9SAndroid Build Coastguard Worker " (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)");
1344*da0073e9SAndroid Build Coastguard Worker }
1345*da0073e9SAndroid Build Coastguard Worker
1346*da0073e9SAndroid Build Coastguard Worker bool split_remainder = should_split(params.block, params.size());
1347*da0073e9SAndroid Build Coastguard Worker return alloc_found_block(
1348*da0073e9SAndroid Build Coastguard Worker params, orig_size, std::move(context), split_remainder);
1349*da0073e9SAndroid Build Coastguard Worker }
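
// A comment-level summary of the retry ladder in malloc() above (no new
// behavior, just the control flow gathered in one place):
//
//   1. get_free_block(params)                          // reuse a cached block
//   2. trigger_free_memory_callbacks(params) + retry   // let registered callbacks free memory
//   3. garbage_collect_cached_blocks(context)          // only when a GC threshold is configured
//   4. alloc_block(params, false, ...)                 // cudaMalloc a new segment
//   5. release_available_cached_blocks(...) + retry    // free enough cached blocks, retry alloc
//   6. release_cached_blocks(...) + retry              // free all non-split cached blocks
//                                                      // (skipped while a capture is underway)
//   7. otherwise record an OOM trace entry, notify the observers, and throw
//      OutOfMemoryError with the message constructed above.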
1350*da0073e9SAndroid Build Coastguard Worker
1351*da0073e9SAndroid Build Coastguard Worker   Block* alloc_found_block(
1352*da0073e9SAndroid Build Coastguard Worker const AllocParams& params,
1353*da0073e9SAndroid Build Coastguard Worker size_t orig_size,
1354*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<GatheredContext> context,
1355*da0073e9SAndroid Build Coastguard Worker bool split_remainder) {
1356*da0073e9SAndroid Build Coastguard Worker auto size = params.size();
1357*da0073e9SAndroid Build Coastguard Worker auto device = params.device();
1358*da0073e9SAndroid Build Coastguard Worker auto pool = params.pool;
1359*da0073e9SAndroid Build Coastguard Worker auto stream = params.stream();
1360*da0073e9SAndroid Build Coastguard Worker
1361*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
1362*da0073e9SAndroid Build Coastguard Worker params.err == cudaSuccess && params.block != nullptr &&
1363*da0073e9SAndroid Build Coastguard Worker params.block->ptr != nullptr);
1364*da0073e9SAndroid Build Coastguard Worker Block* block = params.block;
1365*da0073e9SAndroid Build Coastguard Worker Block* remaining = nullptr;
1366*da0073e9SAndroid Build Coastguard Worker
1367*da0073e9SAndroid Build Coastguard Worker const bool already_split = block->is_split();
1368*da0073e9SAndroid Build Coastguard Worker if (split_remainder) {
1369*da0073e9SAndroid Build Coastguard Worker remaining = block;
1370*da0073e9SAndroid Build Coastguard Worker
1371*da0073e9SAndroid Build Coastguard Worker block = new Block(device, stream, size, pool, block->ptr);
1372*da0073e9SAndroid Build Coastguard Worker block->expandable_segment_ = remaining->expandable_segment_;
1373*da0073e9SAndroid Build Coastguard Worker block->prev = remaining->prev;
1374*da0073e9SAndroid Build Coastguard Worker if (block->prev) {
1375*da0073e9SAndroid Build Coastguard Worker block->prev->next = block;
1376*da0073e9SAndroid Build Coastguard Worker }
1377*da0073e9SAndroid Build Coastguard Worker block->next = remaining;
1378*da0073e9SAndroid Build Coastguard Worker
1379*da0073e9SAndroid Build Coastguard Worker remaining->prev = block;
1380*da0073e9SAndroid Build Coastguard Worker remaining->ptr = static_cast<char*>(remaining->ptr) + size;
1381*da0073e9SAndroid Build Coastguard Worker remaining->size -= size;
1382*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
1383*da0073e9SAndroid Build Coastguard Worker bool inserted = pool->insert_into_blocks(remaining).second;
1384*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
1385*da0073e9SAndroid Build Coastguard Worker
1386*da0073e9SAndroid Build Coastguard Worker if (already_split && !block->expandable_segment_) {
1387*da0073e9SAndroid Build Coastguard Worker // An already-split inactive block is being shrunk by size bytes.
1388*da0073e9SAndroid Build Coastguard Worker decrease_stat_array(
1389*da0073e9SAndroid Build Coastguard Worker stats.inactive_split_bytes, block->size, params.stat_types);
1390*da0073e9SAndroid Build Coastguard Worker } else if (!block->expandable_segment_) {
1391*da0073e9SAndroid Build Coastguard Worker // A new split inactive block is being created from a previously unsplit
1392*da0073e9SAndroid Build Coastguard Worker // block, size remaining->size bytes.
1393*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
1394*da0073e9SAndroid Build Coastguard Worker stats.inactive_split_bytes[stat_type].increase(remaining->size);
1395*da0073e9SAndroid Build Coastguard Worker stats.inactive_split[stat_type].increase(1);
1396*da0073e9SAndroid Build Coastguard Worker });
1397*da0073e9SAndroid Build Coastguard Worker }
1398*da0073e9SAndroid Build Coastguard Worker
1399*da0073e9SAndroid Build Coastguard Worker } else if (already_split && !block->expandable_segment_) {
1400*da0073e9SAndroid Build Coastguard Worker // An already-split block is becoming active
1401*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
1402*da0073e9SAndroid Build Coastguard Worker stats.inactive_split_bytes[stat_type].decrease(block->size);
1403*da0073e9SAndroid Build Coastguard Worker stats.inactive_split[stat_type].decrease(1);
1404*da0073e9SAndroid Build Coastguard Worker });
1405*da0073e9SAndroid Build Coastguard Worker }
1406*da0073e9SAndroid Build Coastguard Worker
1407*da0073e9SAndroid Build Coastguard Worker block->allocated = true;
1408*da0073e9SAndroid Build Coastguard Worker block->requested_size = orig_size;
1409*da0073e9SAndroid Build Coastguard Worker
1410*da0073e9SAndroid Build Coastguard Worker block->context_when_allocated = std::move(context);
1411*da0073e9SAndroid Build Coastguard Worker record_trace(
1412*da0073e9SAndroid Build Coastguard Worker TraceEntry::ALLOC,
1413*da0073e9SAndroid Build Coastguard Worker int64_t(block->ptr),
1414*da0073e9SAndroid Build Coastguard Worker orig_size,
1415*da0073e9SAndroid Build Coastguard Worker block->stream,
1416*da0073e9SAndroid Build Coastguard Worker block->device,
1417*da0073e9SAndroid Build Coastguard Worker block->context_when_allocated);
1418*da0073e9SAndroid Build Coastguard Worker
1419*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
1420*da0073e9SAndroid Build Coastguard Worker bool inserted = active_blocks.insert(block).second;
1421*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
1422*da0073e9SAndroid Build Coastguard Worker
1423*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
1424*da0073e9SAndroid Build Coastguard Worker stats.allocation[stat_type].increase(1);
1425*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[stat_type].increase(block->size);
1426*da0073e9SAndroid Build Coastguard Worker stats.active[stat_type].increase(1);
1427*da0073e9SAndroid Build Coastguard Worker stats.active_bytes[stat_type].increase(block->size);
1428*da0073e9SAndroid Build Coastguard Worker stats.requested_bytes[stat_type].increase(block->requested_size);
1429*da0073e9SAndroid Build Coastguard Worker });
1430*da0073e9SAndroid Build Coastguard Worker if (block->size >= CUDAAllocatorConfig::max_split_size())
1431*da0073e9SAndroid Build Coastguard Worker stats.oversize_allocations.increase(1);
1432*da0073e9SAndroid Build Coastguard Worker
1433*da0073e9SAndroid Build Coastguard Worker auto allocated_bytes_gauge =
1434*da0073e9SAndroid Build Coastguard Worker STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes);
1435*da0073e9SAndroid Build Coastguard Worker allocated_bytes_gauge.record(
1436*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
1437*da0073e9SAndroid Build Coastguard Worker .current);
1438*da0073e9SAndroid Build Coastguard Worker
1439*da0073e9SAndroid Build Coastguard Worker c10::reportMemoryUsageToProfiler(
1440*da0073e9SAndroid Build Coastguard Worker block->ptr,
1441*da0073e9SAndroid Build Coastguard Worker static_cast<int64_t>(block->size),
1442*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
1443*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
1444*da0073e9SAndroid Build Coastguard Worker c10::Device(c10::DeviceType::CUDA, device));
1445*da0073e9SAndroid Build Coastguard Worker
1446*da0073e9SAndroid Build Coastguard Worker return block;
1447*da0073e9SAndroid Build Coastguard Worker }
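
// Illustration (a sketch, not part of the allocator): when split_remainder is
// true above, the found block starting at pointer p is carved into an
// allocated head of `size` bytes and an inactive tail that is reinserted into
// the pool:
//
//   before:  [ p, p + original_size )          free
//   after:   [ p, p + size )                   allocated, returned to caller
//            [ p + size, p + original_size )   "remaining", back in the pool
//
// (`original_size` is illustrative shorthand for the found block's size.)
// The prev/next links stitched together above keep the two pieces adjacent so
// they can be coalesced again when the allocated head is eventually freed.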
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker void free(Block* block) {
1450*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<GatheredContext> context =
1451*da0073e9SAndroid Build Coastguard Worker maybeGatherContext(RecordContext::ALL);
1452*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1453*da0073e9SAndroid Build Coastguard Worker
1454*da0073e9SAndroid Build Coastguard Worker block->allocated = false;
1455*da0073e9SAndroid Build Coastguard Worker
1456*da0073e9SAndroid Build Coastguard Worker // The following logic may modify the underlying Block and change its size,
1457*da0073e9SAndroid Build Coastguard Worker // so we store the original pointer and size ahead of time for reporting.
1458*da0073e9SAndroid Build Coastguard Worker auto orig_block_ptr = block->ptr;
1459*da0073e9SAndroid Build Coastguard Worker auto orig_block_size = block->size;
1460*da0073e9SAndroid Build Coastguard Worker
1461*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = get_stat_types_for_pool(*block->pool);
1462*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
1463*da0073e9SAndroid Build Coastguard Worker stats.allocation[stat_type].decrease(1);
1464*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[stat_type].decrease(block->size);
1465*da0073e9SAndroid Build Coastguard Worker });
1466*da0073e9SAndroid Build Coastguard Worker auto allocated_bytes_gauge =
1467*da0073e9SAndroid Build Coastguard Worker STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes);
1468*da0073e9SAndroid Build Coastguard Worker allocated_bytes_gauge.record(
1469*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
1470*da0073e9SAndroid Build Coastguard Worker .current);
1471*da0073e9SAndroid Build Coastguard Worker
1472*da0073e9SAndroid Build Coastguard Worker record_trace(
1473*da0073e9SAndroid Build Coastguard Worker TraceEntry::FREE_REQUESTED,
1474*da0073e9SAndroid Build Coastguard Worker int64_t(block->ptr),
1475*da0073e9SAndroid Build Coastguard Worker block->requested_size,
1476*da0073e9SAndroid Build Coastguard Worker block->stream,
1477*da0073e9SAndroid Build Coastguard Worker block->device,
1478*da0073e9SAndroid Build Coastguard Worker context ? context : block->context_when_allocated);
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker if (block->size >= CUDAAllocatorConfig::max_split_size())
1481*da0073e9SAndroid Build Coastguard Worker stats.oversize_allocations.decrease(1);
1482*da0073e9SAndroid Build Coastguard Worker
1483*da0073e9SAndroid Build Coastguard Worker if (!block->stream_uses.empty()) {
1484*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(!captures_underway.empty())) {
1485*da0073e9SAndroid Build Coastguard Worker // It's forbidden to cudaEventQuery an event recorded during CUDA graph
1486*da0073e9SAndroid Build Coastguard Worker // capture. We conservatively defer recording end-of-life events until
1487*da0073e9SAndroid Build Coastguard Worker // the next call to process_events() (which won't happen until no
1488*da0073e9SAndroid Build Coastguard Worker // captures are underway)
1489*da0073e9SAndroid Build Coastguard Worker needs_events_deferred_until_no_capture.push_back(block);
1490*da0073e9SAndroid Build Coastguard Worker } else {
1491*da0073e9SAndroid Build Coastguard Worker insert_events(block);
1492*da0073e9SAndroid Build Coastguard Worker }
1493*da0073e9SAndroid Build Coastguard Worker } else {
1494*da0073e9SAndroid Build Coastguard Worker free_block(block, context);
1495*da0073e9SAndroid Build Coastguard Worker }
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker c10::reportMemoryUsageToProfiler(
1498*da0073e9SAndroid Build Coastguard Worker orig_block_ptr,
1499*da0073e9SAndroid Build Coastguard Worker -static_cast<int64_t>(orig_block_size),
1500*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
1501*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
1502*da0073e9SAndroid Build Coastguard Worker c10::Device(c10::DeviceType::CUDA, block->device));
1503*da0073e9SAndroid Build Coastguard Worker }
1504*da0073e9SAndroid Build Coastguard Worker
1505*da0073e9SAndroid Build Coastguard Worker void* getBaseAllocation(Block* block, size_t* outSize) {
1506*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1507*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
1508*da0073e9SAndroid Build Coastguard Worker !block->expandable_segment_,
1509*da0073e9SAndroid Build Coastguard Worker "Tensors allocated with expandable_segments:True cannot be shared between processes. Consider using expandable_segments:False in data loading workers via torch.cuda.memory._set_allocator_settings('expandable_segments:False')");
1510*da0073e9SAndroid Build Coastguard Worker while (block->prev) {
1511*da0073e9SAndroid Build Coastguard Worker block = block->prev;
1512*da0073e9SAndroid Build Coastguard Worker }
1513*da0073e9SAndroid Build Coastguard Worker void* basePtr = block->ptr;
1514*da0073e9SAndroid Build Coastguard Worker if (outSize) {
1515*da0073e9SAndroid Build Coastguard Worker size_t size = 0;
1516*da0073e9SAndroid Build Coastguard Worker while (block) {
1517*da0073e9SAndroid Build Coastguard Worker size += block->size;
1518*da0073e9SAndroid Build Coastguard Worker block = block->next;
1519*da0073e9SAndroid Build Coastguard Worker }
1520*da0073e9SAndroid Build Coastguard Worker *outSize = size;
1521*da0073e9SAndroid Build Coastguard Worker }
1522*da0073e9SAndroid Build Coastguard Worker return basePtr;
1523*da0073e9SAndroid Build Coastguard Worker }
1524*da0073e9SAndroid Build Coastguard Worker
1525*da0073e9SAndroid Build Coastguard Worker ShareableHandle shareIpcHandle(Block* block) {
1526*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1527*da0073e9SAndroid Build Coastguard Worker std::ostringstream ss;
1528*da0073e9SAndroid Build Coastguard Worker ss.put(SHAREABLE_HANDLE_VERSION);
1529*da0073e9SAndroid Build Coastguard Worker ptrdiff_t offset = 0;
1530*da0073e9SAndroid Build Coastguard Worker if (!block->expandable_segment_) {
1531*da0073e9SAndroid Build Coastguard Worker ss.put(SHAREABLE_CUDA_MALLOC);
1532*da0073e9SAndroid Build Coastguard Worker Block* base_block = block;
1533*da0073e9SAndroid Build Coastguard Worker while (base_block->prev) {
1534*da0073e9SAndroid Build Coastguard Worker base_block = base_block->prev;
1535*da0073e9SAndroid Build Coastguard Worker }
1536*da0073e9SAndroid Build Coastguard Worker offset = (char*)block->ptr - (char*)base_block->ptr;
1537*da0073e9SAndroid Build Coastguard Worker cudaIpcMemHandle_t handle;
1538*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaIpcGetMemHandle(&handle, base_block->ptr));
1539*da0073e9SAndroid Build Coastguard Worker ss.write((char*)&handle, CUDA_IPC_HANDLE_SIZE);
1540*da0073e9SAndroid Build Coastguard Worker } else {
1541*da0073e9SAndroid Build Coastguard Worker ss.put(SHAREABLE_CUDA_EXPANDABLE_SEGMENT);
1542*da0073e9SAndroid Build Coastguard Worker auto full_range = block->expandable_segment_->share(
1543*da0073e9SAndroid Build Coastguard Worker SegmentRange(block->ptr, block->size), ss);
1544*da0073e9SAndroid Build Coastguard Worker offset = (char*)block->ptr - (char*)full_range.ptr;
1545*da0073e9SAndroid Build Coastguard Worker }
1546*da0073e9SAndroid Build Coastguard Worker return ShareableHandle{offset, ss.str()};
1547*da0073e9SAndroid Build Coastguard Worker }
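
// Layout of the handle string built above (inferred from the writes in this
// function; a descriptive sketch, not a stability guarantee for the format):
//
//   byte 0      SHAREABLE_HANDLE_VERSION
//   byte 1      SHAREABLE_CUDA_MALLOC or SHAREABLE_CUDA_EXPANDABLE_SEGMENT
//   bytes 2..   a cudaIpcMemHandle_t (CUDA_IPC_HANDLE_SIZE bytes), or the
//               payload written by the expandable segment's share()
//
// ShareableHandle::offset records where block->ptr sits inside the shared base
// allocation (or mapped range), so the importing process can rebuild the
// pointer from the base it opens.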
1548*da0073e9SAndroid Build Coastguard Worker
1549*da0073e9SAndroid Build Coastguard Worker void recordStream(Block* block, cuda::CUDAStream stream) {
1550*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1551*da0073e9SAndroid Build Coastguard Worker if (stream.stream() == block->stream) {
1552*da0073e9SAndroid Build Coastguard Worker // ignore uses on the allocation stream, since those don't require any
1553*da0073e9SAndroid Build Coastguard Worker // special synchronization
1554*da0073e9SAndroid Build Coastguard Worker return;
1555*da0073e9SAndroid Build Coastguard Worker }
1556*da0073e9SAndroid Build Coastguard Worker block->stream_uses.insert(stream);
1557*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(!captures_underway.empty())) {
1558*da0073e9SAndroid Build Coastguard Worker block_to_cudagraph_stream_uses[block].insert(stream);
1559*da0073e9SAndroid Build Coastguard Worker }
1560*da0073e9SAndroid Build Coastguard Worker }
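
// Hedged usage sketch (streamA/streamB are illustrative names, not variables
// in this file): if memory allocated on streamA is later consumed by work
// queued on streamB, the caller records that use; free() will then insert an
// event on streamB and only recycle the block once that work has completed.
//
//   // block was allocated while streamA was current
//   recordStream(block, streamB);   // no-op if streamB == streamA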
1561*da0073e9SAndroid Build Coastguard Worker
1562*da0073e9SAndroid Build Coastguard Worker /** Set the memory fraction that limits the maximum amount of allocated memory **/
1563*da0073e9SAndroid Build Coastguard Worker void setMemoryFraction(double fraction) {
1564*da0073e9SAndroid Build Coastguard Worker size_t device_free = 0;
1565*da0073e9SAndroid Build Coastguard Worker size_t device_total = 0;
1566*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
1567*da0073e9SAndroid Build Coastguard Worker allowed_memory_maximum =
1568*da0073e9SAndroid Build Coastguard Worker static_cast<size_t>(fraction * static_cast<double>(device_total));
1569*da0073e9SAndroid Build Coastguard Worker set_fraction = true;
1570*da0073e9SAndroid Build Coastguard Worker }
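
// Worked example (a sketch; the device size is an assumption, not a measured
// value): with fraction = 0.5 on a device where cudaMemGetInfo reports
// device_total = 16 GiB, allowed_memory_maximum becomes 8 GiB. Once
// set_fraction is true, the allocation paths elsewhere in this file treat
// requests that would push reserved memory past that cap as out-of-memory,
// even if the GPU itself still has free memory.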
1571*da0073e9SAndroid Build Coastguard Worker
1572*da0073e9SAndroid Build Coastguard Worker /** returns cached blocks to the system allocator **/
1573*da0073e9SAndroid Build Coastguard Worker void emptyCache() {
1574*da0073e9SAndroid Build Coastguard Worker auto context = maybeGatherContext(RecordContext::ALL);
1575*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1576*da0073e9SAndroid Build Coastguard Worker release_cached_blocks(context);
1577*da0073e9SAndroid Build Coastguard Worker }
1578*da0073e9SAndroid Build Coastguard Worker
1579*da0073e9SAndroid Build Coastguard Worker /** Retrieves size of largest unused block held by the memory cache **/
1580*da0073e9SAndroid Build Coastguard Worker void cacheInfo(size_t* largest) {
1581*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1582*da0073e9SAndroid Build Coastguard Worker if (*largest ==
1583*da0073e9SAndroid Build Coastguard Worker 0) { // make an initial guess if a zero *largest is passed in
1584*da0073e9SAndroid Build Coastguard Worker size_t tmp_bytes = 0;
1585*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaMemGetInfo(
1586*da0073e9SAndroid Build Coastguard Worker largest, // Use free memory as an optimistic initial guess of *largest
1587*da0073e9SAndroid Build Coastguard Worker &tmp_bytes));
1588*da0073e9SAndroid Build Coastguard Worker }
1589*da0073e9SAndroid Build Coastguard Worker cache_info_aux(large_blocks, largest);
1590*da0073e9SAndroid Build Coastguard Worker cache_info_aux(small_blocks, largest);
1591*da0073e9SAndroid Build Coastguard Worker for (const auto& gp : graph_pools) {
1592*da0073e9SAndroid Build Coastguard Worker cache_info_aux(gp.second->large_blocks, largest);
1593*da0073e9SAndroid Build Coastguard Worker cache_info_aux(gp.second->small_blocks, largest);
1594*da0073e9SAndroid Build Coastguard Worker }
1595*da0073e9SAndroid Build Coastguard Worker }
1596*da0073e9SAndroid Build Coastguard Worker
1597*da0073e9SAndroid Build Coastguard Worker /** Returns a copy of the memory allocator stats **/
1598*da0073e9SAndroid Build Coastguard Worker DeviceStats getStats() {
1599*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1600*da0073e9SAndroid Build Coastguard Worker return stats;
1601*da0073e9SAndroid Build Coastguard Worker }
1602*da0073e9SAndroid Build Coastguard Worker
1603*da0073e9SAndroid Build Coastguard Worker /** Resets the historical accumulation stats for the device **/
1604*da0073e9SAndroid Build Coastguard Worker void resetAccumulatedStats() {
1605*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1606*da0073e9SAndroid Build Coastguard Worker
1607*da0073e9SAndroid Build Coastguard Worker for (const auto statType :
1608*da0073e9SAndroid Build Coastguard Worker c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
1609*da0073e9SAndroid Build Coastguard Worker stats.allocation[statType].reset_accumulated();
1610*da0073e9SAndroid Build Coastguard Worker stats.segment[statType].reset_accumulated();
1611*da0073e9SAndroid Build Coastguard Worker stats.active[statType].reset_accumulated();
1612*da0073e9SAndroid Build Coastguard Worker stats.inactive_split[statType].reset_accumulated();
1613*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[statType].reset_accumulated();
1614*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[statType].reset_accumulated();
1615*da0073e9SAndroid Build Coastguard Worker stats.active_bytes[statType].reset_accumulated();
1616*da0073e9SAndroid Build Coastguard Worker stats.inactive_split_bytes[statType].reset_accumulated();
1617*da0073e9SAndroid Build Coastguard Worker stats.requested_bytes[statType].reset_accumulated();
1618*da0073e9SAndroid Build Coastguard Worker }
1619*da0073e9SAndroid Build Coastguard Worker
1620*da0073e9SAndroid Build Coastguard Worker stats.num_alloc_retries = 0;
1621*da0073e9SAndroid Build Coastguard Worker stats.num_ooms = 0;
1622*da0073e9SAndroid Build Coastguard Worker stats.num_sync_all_streams = 0;
1623*da0073e9SAndroid Build Coastguard Worker stats.num_device_alloc = 0;
1624*da0073e9SAndroid Build Coastguard Worker stats.num_device_free = 0;
1625*da0073e9SAndroid Build Coastguard Worker stats.oversize_allocations.reset_accumulated();
1626*da0073e9SAndroid Build Coastguard Worker stats.oversize_segments.reset_accumulated();
1627*da0073e9SAndroid Build Coastguard Worker }
1628*da0073e9SAndroid Build Coastguard Worker
1629*da0073e9SAndroid Build Coastguard Worker /** Resets the historical peak stats for the device **/
1630*da0073e9SAndroid Build Coastguard Worker void resetPeakStats() {
1631*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1632*da0073e9SAndroid Build Coastguard Worker
1633*da0073e9SAndroid Build Coastguard Worker for (const auto statType :
1634*da0073e9SAndroid Build Coastguard Worker c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
1635*da0073e9SAndroid Build Coastguard Worker stats.allocation[statType].reset_peak();
1636*da0073e9SAndroid Build Coastguard Worker stats.segment[statType].reset_peak();
1637*da0073e9SAndroid Build Coastguard Worker stats.active[statType].reset_peak();
1638*da0073e9SAndroid Build Coastguard Worker stats.inactive_split[statType].reset_peak();
1639*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[statType].reset_peak();
1640*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[statType].reset_peak();
1641*da0073e9SAndroid Build Coastguard Worker stats.active_bytes[statType].reset_peak();
1642*da0073e9SAndroid Build Coastguard Worker stats.inactive_split_bytes[statType].reset_peak();
1643*da0073e9SAndroid Build Coastguard Worker stats.requested_bytes[statType].reset_peak();
1644*da0073e9SAndroid Build Coastguard Worker }
1645*da0073e9SAndroid Build Coastguard Worker stats.oversize_allocations.reset_peak();
1646*da0073e9SAndroid Build Coastguard Worker stats.oversize_segments.reset_peak();
1647*da0073e9SAndroid Build Coastguard Worker }
1648*da0073e9SAndroid Build Coastguard Worker
1649*da0073e9SAndroid Build Coastguard Worker /* Capture the state of a private pool: everything necessary to later return
1650*da0073e9SAndroid Build Coastguard Worker * the pool to its current state */
1651*da0073e9SAndroid Build Coastguard Worker std::unique_ptr<PrivatePoolState> getCheckpointState(MempoolId_t id) {
1652*da0073e9SAndroid Build Coastguard Worker auto context = maybeGatherContext(RecordContext::ALL);
1653*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1654*da0073e9SAndroid Build Coastguard Worker insert_events_deferred_until_no_capture(context);
1655*da0073e9SAndroid Build Coastguard Worker
1656*da0073e9SAndroid Build Coastguard Worker auto pool = graph_pools.find(id);
1657*da0073e9SAndroid Build Coastguard Worker if (pool != graph_pools.end()) {
1658*da0073e9SAndroid Build Coastguard Worker auto private_pool_head_blocks =
1659*da0073e9SAndroid Build Coastguard Worker get_private_pool_head_blocks(pool->second.get());
1660*da0073e9SAndroid Build Coastguard Worker return std::make_unique<PrivatePoolState>(id, private_pool_head_blocks);
1661*da0073e9SAndroid Build Coastguard Worker } else if (graph_pools_freeable.count(id)) {
1662*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "Not expected to checkpoint freeable graph");
1663*da0073e9SAndroid Build Coastguard Worker } else {
1664*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "Could not find pool of id");
1665*da0073e9SAndroid Build Coastguard Worker }
1666*da0073e9SAndroid Build Coastguard Worker }
1667*da0073e9SAndroid Build Coastguard Worker
1668*da0073e9SAndroid Build Coastguard Worker void freeBlocksAllocatedToPool(PrivatePool* private_pool, RestoreResult& rr) {
1669*da0073e9SAndroid Build Coastguard Worker auto pool_blocks = get_private_pool_head_blocks(private_pool);
1670*da0073e9SAndroid Build Coastguard Worker
1671*da0073e9SAndroid Build Coastguard Worker std::vector<Block*> head_blocks;
1672*da0073e9SAndroid Build Coastguard Worker for (Block* block : pool_blocks) {
1673*da0073e9SAndroid Build Coastguard Worker if (block->prev == nullptr) {
1674*da0073e9SAndroid Build Coastguard Worker head_blocks.push_back(block);
1675*da0073e9SAndroid Build Coastguard Worker }
1676*da0073e9SAndroid Build Coastguard Worker }
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker for (Block* block : head_blocks) {
1679*da0073e9SAndroid Build Coastguard Worker Block* curr = block;
1680*da0073e9SAndroid Build Coastguard Worker
1681*da0073e9SAndroid Build Coastguard Worker while (curr) {
1682*da0073e9SAndroid Build Coastguard Worker // When we free a block, its pointer never changes; only its adjacent
1683*da0073e9SAndroid Build Coastguard Worker // blocks can change. So free the block first, then follow its next pointer.
1684*da0073e9SAndroid Build Coastguard Worker if (curr->allocated) {
1685*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
1686*da0073e9SAndroid Build Coastguard Worker curr->event_count == 0,
1687*da0073e9SAndroid Build Coastguard Worker "Events should have synchronized when setting checkpointed block");
1688*da0073e9SAndroid Build Coastguard Worker rr.allocations_freed.push_back(curr->ptr);
1689*da0073e9SAndroid Build Coastguard Worker free(curr);
1690*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(!curr->allocated)
1691*da0073e9SAndroid Build Coastguard Worker }
1692*da0073e9SAndroid Build Coastguard Worker curr = curr->next;
1693*da0073e9SAndroid Build Coastguard Worker }
1694*da0073e9SAndroid Build Coastguard Worker }
1695*da0073e9SAndroid Build Coastguard Worker
1696*da0073e9SAndroid Build Coastguard Worker for (Block* b : get_private_pool_head_blocks(private_pool)) {
1697*da0073e9SAndroid Build Coastguard Worker Block* curr = b;
1698*da0073e9SAndroid Build Coastguard Worker while (curr) {
1699*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(!curr->allocated);
1700*da0073e9SAndroid Build Coastguard Worker curr = curr->next;
1701*da0073e9SAndroid Build Coastguard Worker }
1702*da0073e9SAndroid Build Coastguard Worker }
1703*da0073e9SAndroid Build Coastguard Worker }
1704*da0073e9SAndroid Build Coastguard Worker
1705*da0073e9SAndroid Build Coastguard Worker // restore a segment (an allocation that may have been
1706*da0073e9SAndroid Build Coastguard Worker // split into multiple blocks) to its checkpointed state
1707*da0073e9SAndroid Build Coastguard Worker void setSegmentStateToCheckpoint(
1708*da0073e9SAndroid Build Coastguard Worker Block* block,
1709*da0073e9SAndroid Build Coastguard Worker SegmentState& segment,
1710*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& context,
1711*da0073e9SAndroid Build Coastguard Worker RestoreResult& rr) {
1712*da0073e9SAndroid Build Coastguard Worker Block* curr_block = block;
1713*da0073e9SAndroid Build Coastguard Worker Block* last_block = block;
1714*da0073e9SAndroid Build Coastguard Worker
1715*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(block->pool);
1716*da0073e9SAndroid Build Coastguard Worker BlockPool& pool = *block->pool;
1717*da0073e9SAndroid Build Coastguard Worker const auto segment_len = segment.blocks.size();
1718*da0073e9SAndroid Build Coastguard Worker
1719*da0073e9SAndroid Build Coastguard Worker // allocate all blocks in the segment
1720*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < segment_len; ++i) {
1721*da0073e9SAndroid Build Coastguard Worker // The last block in every expandable segment is the remaining amount of
1722*da0073e9SAndroid Build Coastguard Worker // available unmapped virtual address space. We shouldn't change it; instead,
1723*da0073e9SAndroid Build Coastguard Worker // check that it is correctly formed and then skip allocating it.
1724*da0073e9SAndroid Build Coastguard Worker if (i == segment_len - 1 && curr_block->expandable_segment_) {
1725*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(curr_block->next == nullptr);
1726*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(!curr_block->mapped);
1727*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(curr_block->allocated == false);
1728*da0073e9SAndroid Build Coastguard Worker continue;
1729*da0073e9SAndroid Build Coastguard Worker }
1730*da0073e9SAndroid Build Coastguard Worker
1731*da0073e9SAndroid Build Coastguard Worker auto& block_state = segment.blocks.at(i);
1732*da0073e9SAndroid Build Coastguard Worker AllocParams params(
1733*da0073e9SAndroid Build Coastguard Worker block_state.device,
1734*da0073e9SAndroid Build Coastguard Worker block_state.size,
1735*da0073e9SAndroid Build Coastguard Worker block_state.stream,
1736*da0073e9SAndroid Build Coastguard Worker &pool,
1737*da0073e9SAndroid Build Coastguard Worker block_state.size,
1738*da0073e9SAndroid Build Coastguard Worker stats);
1739*da0073e9SAndroid Build Coastguard Worker pool.blocks.erase(curr_block);
1740*da0073e9SAndroid Build Coastguard Worker params.block = curr_block;
1741*da0073e9SAndroid Build Coastguard Worker params.stat_types = get_stat_types_for_pool(pool);
1742*da0073e9SAndroid Build Coastguard Worker
1743*da0073e9SAndroid Build Coastguard Worker // Splitting a block depends on `max_split_size`, which may have changed
1744*da0073e9SAndroid Build Coastguard Worker // between when the checkpoint was taken and now, so we make sure to recreate
1745*da0073e9SAndroid Build Coastguard Worker // the behavior from the checkpoint. Keep splitting as long as there is
1746*da0073e9SAndroid Build Coastguard Worker // space left over: the checkpointed block state records exactly how large
1747*da0073e9SAndroid Build Coastguard Worker // this block should be within the segment, so any leftover space belongs
1748*da0073e9SAndroid Build Coastguard Worker // to the next block.
1749*da0073e9SAndroid Build Coastguard Worker bool split = curr_block->size > block_state.size;
1750*da0073e9SAndroid Build Coastguard Worker
1751*da0073e9SAndroid Build Coastguard Worker // If the block is split, the original curr_block object becomes the
1752*da0073e9SAndroid Build Coastguard Worker // remaining (next) block, so reassign curr_block from the returned value.
1753*da0073e9SAndroid Build Coastguard Worker curr_block = alloc_found_block(params, block_state.size, context, split);
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(curr_block->ptr == block_state.ptr);
1756*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(curr_block->size == block_state.size);
1757*da0073e9SAndroid Build Coastguard Worker
1758*da0073e9SAndroid Build Coastguard Worker last_block = curr_block;
1759*da0073e9SAndroid Build Coastguard Worker curr_block = curr_block->next;
1760*da0073e9SAndroid Build Coastguard Worker
1761*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK((curr_block != nullptr) == ((i + 1) < (segment_len)));
1762*da0073e9SAndroid Build Coastguard Worker }
1763*da0073e9SAndroid Build Coastguard Worker
1764*da0073e9SAndroid Build Coastguard Worker while (last_block->prev) {
1765*da0073e9SAndroid Build Coastguard Worker last_block = last_block->prev;
1766*da0073e9SAndroid Build Coastguard Worker }
1767*da0073e9SAndroid Build Coastguard Worker
1768*da0073e9SAndroid Build Coastguard Worker // free blocks that are not allocated in the checkpoint
1769*da0073e9SAndroid Build Coastguard Worker curr_block = last_block;
1770*da0073e9SAndroid Build Coastguard Worker
1771*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < segment_len; ++i, curr_block = curr_block->next) {
1772*da0073e9SAndroid Build Coastguard Worker if (i == segment_len - 1 && curr_block->expandable_segment_) {
1773*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(curr_block->next == nullptr);
1774*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(!curr_block->mapped);
1775*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(curr_block->allocated == false);
1776*da0073e9SAndroid Build Coastguard Worker continue;
1777*da0073e9SAndroid Build Coastguard Worker }
1778*da0073e9SAndroid Build Coastguard Worker
1779*da0073e9SAndroid Build Coastguard Worker auto& block_state = segment.blocks.at(i);
1780*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(curr_block != nullptr);
1781*da0073e9SAndroid Build Coastguard Worker
1782*da0073e9SAndroid Build Coastguard Worker if (block_state.allocated) {
1783*da0073e9SAndroid Build Coastguard Worker rr.allocations_created.push_back(curr_block);
1784*da0073e9SAndroid Build Coastguard Worker continue;
1785*da0073e9SAndroid Build Coastguard Worker }
1786*da0073e9SAndroid Build Coastguard Worker
1787*da0073e9SAndroid Build Coastguard Worker free(curr_block);
1788*da0073e9SAndroid Build Coastguard Worker
1789*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(curr_block->ptr == block_state.ptr);
1790*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(curr_block->allocated == block_state.allocated);
1791*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(curr_block->size == block_state.size);
1792*da0073e9SAndroid Build Coastguard Worker }
1793*da0073e9SAndroid Build Coastguard Worker }
1794*da0073e9SAndroid Build Coastguard Worker
1795*da0073e9SAndroid Build Coastguard Worker /**
1796*da0073e9SAndroid Build Coastguard Worker * Note [Checkpointing PrivatePoolState]
1797*da0073e9SAndroid Build Coastguard Worker *
1798*da0073e9SAndroid Build Coastguard Worker * Refer above to Note [Interaction with CUDA graph capture]. Allocations made
1799*da0073e9SAndroid Build Coastguard Worker * during graph capture are made from a separate private pool. During graph
1800*da0073e9SAndroid Build Coastguard Worker * capture allocations behave as usual. During graph replay the allocator
1801*da0073e9SAndroid Build Coastguard Worker * state does not change even as new tensors are created. The private pool
1802*da0073e9SAndroid Build Coastguard Worker * will not free its blocks to the main caching allocator until cuda graph use
1803*da0073e9SAndroid Build Coastguard Worker * is finished, to prevent an eager-mode allocation from clobbering the memory
1804*da0073e9SAndroid Build Coastguard Worker * of a live but unaccounted-for tensor that was created during replay.
1805*da0073e9SAndroid Build Coastguard Worker *
1806*da0073e9SAndroid Build Coastguard Worker * `make_graphed_callables`, a series of separate callables chained in
1807*da0073e9SAndroid Build Coastguard Worker * successive cuda graphs, can share a memory pool because after a cuda graph
1808*da0073e9SAndroid Build Coastguard Worker * recording the allocations in the shared private pool exactly reflect the
1809*da0073e9SAndroid Build Coastguard Worker * tensors that are allocated.
1810*da0073e9SAndroid Build Coastguard Worker *
1811*da0073e9SAndroid Build Coastguard Worker * We would like to extend callable chaining to support a graphed callable
1812*da0073e9SAndroid Build Coastguard Worker * tree. In this scenario, we have a tree of callable chains which will be
1813*da0073e9SAndroid Build Coastguard Worker * captured with cuda graphs. In the diagram below, we have a tree with four
1814*da0073e9SAndroid Build Coastguard Worker * callables, A, B, C, and D. Suppose we have captured, and subsequently
1815*da0073e9SAndroid Build Coastguard Worker * replayed, A, B, and C. Then on a new invocation, we replay A and B, but
1816*da0073e9SAndroid Build Coastguard Worker * would now like to record D. At this point the private pool will not reflect
1817*da0073e9SAndroid Build Coastguard Worker * any of the live tensors created during graph replay. Allocations made
1818*da0073e9SAndroid Build Coastguard Worker * during a new recording with the pool could overwrite those live tensors.
1819*da0073e9SAndroid Build Coastguard Worker *
1820*da0073e9SAndroid Build Coastguard Worker * In order to record a new graph capture after replaying prior callables in
1821*da0073e9SAndroid Build Coastguard Worker * the tree, we need the allocator to reflect the state of the live tensors.
1822*da0073e9SAndroid Build Coastguard Worker * We checkpoint the state of the private pool after each recording, and then
1823*da0073e9SAndroid Build Coastguard Worker * reapply it when we are starting a new recording chain. Additionally, we
1824*da0073e9SAndroid Build Coastguard Worker * must free the allocations for any tensors that died between the end of our
1825*da0073e9SAndroid Build Coastguard Worker * previous graph replaying and our new recording. All of the allocated
1826*da0073e9SAndroid Build Coastguard Worker * segments that existed in the checkpointed state must still exist in the
1827*da0073e9SAndroid Build Coastguard Worker * pool. There may also exist new allocated blocks.
1828*da0073e9SAndroid Build Coastguard Worker * (TODO : link note [live tensors between iterations] when it exists). For
1829*da0073e9SAndroid Build Coastguard Worker * every block that is currently allocated but not allocated in the snapshot,
1830*da0073e9SAndroid Build Coastguard Worker * we will return a pointer to its block.
1831*da0073e9SAndroid Build Coastguard Worker *
1832*da0073e9SAndroid Build Coastguard Worker *
1833*da0073e9SAndroid Build Coastguard Worker *
1834*da0073e9SAndroid Build Coastguard Worker * ---------------> A ---------------> B ---------------> C
1835*da0073e9SAndroid Build Coastguard Worker * |
1836*da0073e9SAndroid Build Coastguard Worker * |
1837*da0073e9SAndroid Build Coastguard Worker * |
1838*da0073e9SAndroid Build Coastguard Worker * |
1839*da0073e9SAndroid Build Coastguard Worker * ╰ ---------------> D
1840*da0073e9SAndroid Build Coastguard Worker */
1841*da0073e9SAndroid Build Coastguard Worker RestoreResult setCheckpointPoolState(PrivatePoolState& pps) {
1842*da0073e9SAndroid Build Coastguard Worker // To reset the caching allocator state we will
1843*da0073e9SAndroid Build Coastguard Worker // - Free all the blocks currently allocated to the pool (see [live tensors
1844*da0073e9SAndroid Build Coastguard Worker // between iterations])
1845*da0073e9SAndroid Build Coastguard Worker // - Allocate all the blocks in a checkpointed segment, whether they are
1846*da0073e9SAndroid Build Coastguard Worker // live or not
1847*da0073e9SAndroid Build Coastguard Worker // - Free the blocks in a checkpointed segment which are not live
1848*da0073e9SAndroid Build Coastguard Worker // This could be optimized, but it nicely reuses existing APIs, and this
1849*da0073e9SAndroid Build Coastguard Worker // is not on the hot path.
1850*da0073e9SAndroid Build Coastguard Worker
1851*da0073e9SAndroid Build Coastguard Worker // Context gathering is done outside the lock because we don't know what
1852*da0073e9SAndroid Build Coastguard Worker // locks the recorder needs to hold.
1853*da0073e9SAndroid Build Coastguard Worker
1854*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<GatheredContext> context =
1855*da0073e9SAndroid Build Coastguard Worker maybeGatherContext(RecordContext::STATE);
1856*da0073e9SAndroid Build Coastguard Worker
1857*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1858*da0073e9SAndroid Build Coastguard Worker
1859*da0073e9SAndroid Build Coastguard Worker RestoreResult rr;
1860*da0073e9SAndroid Build Coastguard Worker
1861*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
1862*da0073e9SAndroid Build Coastguard Worker !graph_pools_freeable.count(pps.owner_id),
1863*da0073e9SAndroid Build Coastguard Worker "Not expected to checkpoint freeable graph");
1864*da0073e9SAndroid Build Coastguard Worker
1865*da0073e9SAndroid Build Coastguard Worker auto pool = graph_pools.find(pps.owner_id);
1866*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(pool != graph_pools.end(), "Could not find private pool id");
1867*da0073e9SAndroid Build Coastguard Worker
1868*da0073e9SAndroid Build Coastguard Worker PrivatePool* private_pool = pool->second.get();
1869*da0073e9SAndroid Build Coastguard Worker
1870*da0073e9SAndroid Build Coastguard Worker freeBlocksAllocatedToPool(private_pool, rr);
1871*da0073e9SAndroid Build Coastguard Worker
1872*da0073e9SAndroid Build Coastguard Worker std::unordered_map<void*, Block*> ptrs_to_blocks;
1873*da0073e9SAndroid Build Coastguard Worker // at this point, all of the blocks should be free, so they will all be in
1874*da0073e9SAndroid Build Coastguard Worker // the block set
1875*da0073e9SAndroid Build Coastguard Worker for (Block* block : private_pool->small_blocks.blocks) {
1876*da0073e9SAndroid Build Coastguard Worker ptrs_to_blocks[block->ptr] = block;
1877*da0073e9SAndroid Build Coastguard Worker }
1878*da0073e9SAndroid Build Coastguard Worker for (Block* block : private_pool->large_blocks.blocks) {
1879*da0073e9SAndroid Build Coastguard Worker ptrs_to_blocks[block->ptr] = block;
1880*da0073e9SAndroid Build Coastguard Worker }
1881*da0073e9SAndroid Build Coastguard Worker
1882*da0073e9SAndroid Build Coastguard Worker for (auto& segment : pps.segments) {
1883*da0073e9SAndroid Build Coastguard Worker auto ptr = segment.blocks.at(0).ptr;
1884*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(ptrs_to_blocks.count(ptr), " could not find ", ptr)
1885*da0073e9SAndroid Build Coastguard Worker auto block = ptrs_to_blocks[ptr];
1886*da0073e9SAndroid Build Coastguard Worker
1887*da0073e9SAndroid Build Coastguard Worker setSegmentStateToCheckpoint(block, segment, context, rr);
1888*da0073e9SAndroid Build Coastguard Worker }
1889*da0073e9SAndroid Build Coastguard Worker return rr;
1890*da0073e9SAndroid Build Coastguard Worker }
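
// Hedged usage sketch (the `alloc` and `pool_id` names are illustrative, not
// APIs defined here): to extend a graphed-callable tree, checkpoint the pool
// right after a recording and reapply that checkpoint before the next
// recording, once tensors that died in between have been freed:
//
//   auto state = alloc.getCheckpointState(pool_id);      // after recording C
//   // ... replay A and B; some tensors from the checkpoint may have died ...
//   RestoreResult rr = alloc.setCheckpointPoolState(*state);
//   // rr.allocations_freed / rr.allocations_created report the pointers the
//   // caller must reconcile with its own notion of live tensors.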
1891*da0073e9SAndroid Build Coastguard Worker
1892*da0073e9SAndroid Build Coastguard Worker /** Dump a complete snapshot of the memory held by the allocator. Potentially
1893*da0073e9SAndroid Build Coastguard Worker * VERY expensive. **/
1894*da0073e9SAndroid Build Coastguard Worker std::vector<SegmentInfo> snapshot() {
1895*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1896*da0073e9SAndroid Build Coastguard Worker
1897*da0073e9SAndroid Build Coastguard Worker std::unordered_map<PrivatePool*, MempoolId_t> pool_to_id;
1898*da0073e9SAndroid Build Coastguard Worker pool_to_id.reserve(graph_pools.size() + graph_pools_freeable.size());
1899*da0073e9SAndroid Build Coastguard Worker for (const auto& pair : graph_pools) {
1900*da0073e9SAndroid Build Coastguard Worker pool_to_id[pair.second.get()] = pair.first;
1901*da0073e9SAndroid Build Coastguard Worker }
1902*da0073e9SAndroid Build Coastguard Worker for (const auto& pair : graph_pools_freeable) {
1903*da0073e9SAndroid Build Coastguard Worker pool_to_id[pair.second] = pair.first;
1904*da0073e9SAndroid Build Coastguard Worker }
1905*da0073e9SAndroid Build Coastguard Worker
1906*da0073e9SAndroid Build Coastguard Worker size_t total_active = 0;
1907*da0073e9SAndroid Build Coastguard Worker std::vector<SegmentInfo> result;
1908*da0073e9SAndroid Build Coastguard Worker const auto all_blocks = get_all_blocks();
1909*da0073e9SAndroid Build Coastguard Worker
1910*da0073e9SAndroid Build Coastguard Worker for (const Block* const head_block : all_blocks) {
1911*da0073e9SAndroid Build Coastguard Worker // For expandable segments, we report one segment for each contiguous
1912*da0073e9SAndroid Build Coastguard Worker // mapped range of memory
1913*da0073e9SAndroid Build Coastguard Worker if (head_block->prev && head_block->prev->mapped) {
1914*da0073e9SAndroid Build Coastguard Worker continue;
1915*da0073e9SAndroid Build Coastguard Worker }
1916*da0073e9SAndroid Build Coastguard Worker result.emplace_back();
1917*da0073e9SAndroid Build Coastguard Worker SegmentInfo& segment_info = result.back();
1918*da0073e9SAndroid Build Coastguard Worker segment_info.device = head_block->device;
1919*da0073e9SAndroid Build Coastguard Worker segment_info.address = reinterpret_cast<size_t>(head_block->ptr);
1920*da0073e9SAndroid Build Coastguard Worker segment_info.stream = head_block->stream;
1921*da0073e9SAndroid Build Coastguard Worker segment_info.is_large = (!head_block->pool->is_small);
1922*da0073e9SAndroid Build Coastguard Worker segment_info.is_expandable = head_block->expandable_segment_;
1923*da0073e9SAndroid Build Coastguard Worker segment_info.context_when_allocated =
1924*da0073e9SAndroid Build Coastguard Worker head_block->context_when_segment_allocated;
1925*da0073e9SAndroid Build Coastguard Worker auto mempool_id = pool_to_id.find(head_block->pool->owner_PrivatePool);
1926*da0073e9SAndroid Build Coastguard Worker if (mempool_id != pool_to_id.end()) {
1927*da0073e9SAndroid Build Coastguard Worker segment_info.owner_private_pool_id = mempool_id->second;
1928*da0073e9SAndroid Build Coastguard Worker }
1929*da0073e9SAndroid Build Coastguard Worker
1930*da0073e9SAndroid Build Coastguard Worker const Block* block = head_block;
1931*da0073e9SAndroid Build Coastguard Worker while (block != nullptr && block->mapped) {
1932*da0073e9SAndroid Build Coastguard Worker segment_info.blocks.emplace_back();
1933*da0073e9SAndroid Build Coastguard Worker BlockInfo& block_info = segment_info.blocks.back();
1934*da0073e9SAndroid Build Coastguard Worker
1935*da0073e9SAndroid Build Coastguard Worker block_info.size = block->size;
1936*da0073e9SAndroid Build Coastguard Worker block_info.requested_size = block->requested_size;
1937*da0073e9SAndroid Build Coastguard Worker block_info.allocated = block->allocated;
1938*da0073e9SAndroid Build Coastguard Worker block_info.active = block->allocated || (block->event_count > 0) ||
1939*da0073e9SAndroid Build Coastguard Worker !block->stream_uses.empty();
1940*da0073e9SAndroid Build Coastguard Worker
1941*da0073e9SAndroid Build Coastguard Worker segment_info.total_size += block_info.size;
1942*da0073e9SAndroid Build Coastguard Worker if (block_info.allocated) {
1943*da0073e9SAndroid Build Coastguard Worker segment_info.allocated_size += block_info.size;
1944*da0073e9SAndroid Build Coastguard Worker }
1945*da0073e9SAndroid Build Coastguard Worker if (block_info.active) {
1946*da0073e9SAndroid Build Coastguard Worker segment_info.active_size += block_info.size;
1947*da0073e9SAndroid Build Coastguard Worker segment_info.requested_size += block_info.requested_size;
1948*da0073e9SAndroid Build Coastguard Worker }
1949*da0073e9SAndroid Build Coastguard Worker block_info.context_when_allocated = block->context_when_allocated;
1950*da0073e9SAndroid Build Coastguard Worker block = block->next;
1951*da0073e9SAndroid Build Coastguard Worker }
1952*da0073e9SAndroid Build Coastguard Worker total_active += segment_info.active_size;
1953*da0073e9SAndroid Build Coastguard Worker }
1954*da0073e9SAndroid Build Coastguard Worker
1955*da0073e9SAndroid Build Coastguard Worker std::sort(
1956*da0073e9SAndroid Build Coastguard Worker result.begin(),
1957*da0073e9SAndroid Build Coastguard Worker result.end(),
1958*da0073e9SAndroid Build Coastguard Worker [](const SegmentInfo& a, const SegmentInfo& b) {
1959*da0073e9SAndroid Build Coastguard Worker return a.address < b.address;
1960*da0073e9SAndroid Build Coastguard Worker });
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Worker record_trace(TraceEntry::SNAPSHOT, 0, total_active, nullptr, 0, nullptr);
1963*da0073e9SAndroid Build Coastguard Worker return result;
1964*da0073e9SAndroid Build Coastguard Worker }
1965*da0073e9SAndroid Build Coastguard Worker
1966*da0073e9SAndroid Build Coastguard Worker std::vector<TraceEntry> trace(
1967*da0073e9SAndroid Build Coastguard Worker const std::function<time_t(approx_time_t)>& tsc_to_us) {
1968*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
1969*da0073e9SAndroid Build Coastguard Worker std::vector<TraceEntry> result;
1970*da0073e9SAndroid Build Coastguard Worker alloc_buffer.getEntries(result);
1971*da0073e9SAndroid Build Coastguard Worker
1972*da0073e9SAndroid Build Coastguard Worker // Convert all the timestamps from tsc to epoch time in microseconds.
1973*da0073e9SAndroid Build Coastguard Worker for (auto& te : result) {
1974*da0073e9SAndroid Build Coastguard Worker te.time_.t_ = tsc_to_us(te.time_.approx_t_);
1975*da0073e9SAndroid Build Coastguard Worker }
1976*da0073e9SAndroid Build Coastguard Worker return result;
1977*da0073e9SAndroid Build Coastguard Worker }
1978*da0073e9SAndroid Build Coastguard Worker
1979*da0073e9SAndroid Build Coastguard Worker // This function takes the size and number of divisions argument and rounds
1980*da0073e9SAndroid Build Coastguard Worker // up the size argument for the nearest power-of-2 division.
1981*da0073e9SAndroid Build Coastguard Worker // For example, if we need to round-up 1200 and number of divisions is 4,
1982*da0073e9SAndroid Build Coastguard Worker // the size 1200 lies between 1024 and 2048 and if we do 4 divisions between
1983*da0073e9SAndroid Build Coastguard Worker // them, the values are 1024, 1280, 1536, and 1792. So the function will
1984*da0073e9SAndroid Build Coastguard Worker // return 1280 as the nearest power-of-2 division ceiling.
1985*da0073e9SAndroid Build Coastguard Worker static size_t roundup_power2_next_division(size_t size, size_t divisions) {
1986*da0073e9SAndroid Build Coastguard Worker if (llvm::isPowerOf2_64(size)) {
1987*da0073e9SAndroid Build Coastguard Worker return size;
1988*da0073e9SAndroid Build Coastguard Worker }
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(divisions >= 2, "Only 2 or more divisions are supported");
1991*da0073e9SAndroid Build Coastguard Worker
1992*da0073e9SAndroid Build Coastguard Worker // Divide the space between these two powers of 2 into equal divisions.
1993*da0073e9SAndroid Build Coastguard Worker // If the per-division size underflows to zero, return the power-of-2 ceiling.
1994*da0073e9SAndroid Build Coastguard Worker size_t power2_floor = llvm::PowerOf2Floor(size);
1995*da0073e9SAndroid Build Coastguard Worker size_t power2_divison =
1996*da0073e9SAndroid Build Coastguard Worker power2_floor >> (63 - llvm::countLeadingZeros(divisions));
1997*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(power2_divison == 0)) {
1998*da0073e9SAndroid Build Coastguard Worker return (power2_floor << 1);
1999*da0073e9SAndroid Build Coastguard Worker }
2000*da0073e9SAndroid Build Coastguard Worker size_t round_size_floor = size & (~(power2_divison - 1));
2001*da0073e9SAndroid Build Coastguard Worker return (round_size_floor == size) ? size
2002*da0073e9SAndroid Build Coastguard Worker : round_size_floor + power2_divison;
2003*da0073e9SAndroid Build Coastguard Worker }
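
// Worked examples (a sketch, not part of any test): with divisions = 4, the
// candidate sizes between 1024 and 2048 are 1024, 1280, 1536, 1792, 2048, so:
//
//   roundup_power2_next_division(1200, 4) == 1280
//   roundup_power2_next_division(1793, 4) == 2048
//   roundup_power2_next_division(1024, 4) == 1024   // powers of two unchanged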
2004*da0073e9SAndroid Build Coastguard Worker
2005*da0073e9SAndroid Build Coastguard Worker static size_t round_size(size_t size) {
2006*da0073e9SAndroid Build Coastguard Worker if (size < kMinBlockSize) {
2007*da0073e9SAndroid Build Coastguard Worker return kMinBlockSize;
2008*da0073e9SAndroid Build Coastguard Worker } else {
2009*da0073e9SAndroid Build Coastguard Worker auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size);
2010*da0073e9SAndroid Build Coastguard Worker if (divisions > 1 && size > (kMinBlockSize * divisions)) {
2011*da0073e9SAndroid Build Coastguard Worker return roundup_power2_next_division(size, divisions);
2012*da0073e9SAndroid Build Coastguard Worker } else {
2013*da0073e9SAndroid Build Coastguard Worker return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
2014*da0073e9SAndroid Build Coastguard Worker }
2015*da0073e9SAndroid Build Coastguard Worker }
2016*da0073e9SAndroid Build Coastguard Worker }
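
// Worked examples (a sketch, assuming kMinBlockSize is the 512-byte constant
// defined earlier in this file and that no power-of-2 division applies):
//
//   round_size(1)    == 512    // never round below kMinBlockSize
//   round_size(1000) == 1024   // rounded up to a multiple of kMinBlockSize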
2017*da0073e9SAndroid Build Coastguard Worker
2018*da0073e9SAndroid Build Coastguard Worker // See Note [Interaction with CUDA graph capture]
2019*da0073e9SAndroid Build Coastguard Worker
2020*da0073e9SAndroid Build Coastguard Worker // Called by CUDAGraph::capture_begin
2021*da0073e9SAndroid Build Coastguard Worker void beginAllocateToPool(
2022*da0073e9SAndroid Build Coastguard Worker MempoolId_t mempool_id,
2023*da0073e9SAndroid Build Coastguard Worker std::function<bool(cudaStream_t)> filter) {
2024*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
2025*da0073e9SAndroid Build Coastguard Worker auto it = graph_pools.find(mempool_id);
2026*da0073e9SAndroid Build Coastguard Worker if (it == graph_pools.end()) {
2027*da0073e9SAndroid Build Coastguard Worker // mempool_id does not reference an existing pool. Make a new pool for
2028*da0073e9SAndroid Build Coastguard Worker // this capture.
2029*da0073e9SAndroid Build Coastguard Worker graph_pools.emplace(mempool_id, std::make_unique<PrivatePool>());
2030*da0073e9SAndroid Build Coastguard Worker } else {
2031*da0073e9SAndroid Build Coastguard Worker // mempool_id references an existing pool, which the current capture will
2032*da0073e9SAndroid Build Coastguard Worker // share. Check this pool is live (at least one other capture already
2033*da0073e9SAndroid Build Coastguard Worker // references it).
2034*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
2035*da0073e9SAndroid Build Coastguard Worker it->second->use_count++;
2036*da0073e9SAndroid Build Coastguard Worker }
2037*da0073e9SAndroid Build Coastguard Worker for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
2038*da0073e9SAndroid Build Coastguard Worker ++it2) {
2039*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
2040*da0073e9SAndroid Build Coastguard Worker it2->first != mempool_id,
2041*da0073e9SAndroid Build Coastguard Worker "beginAllocateToPool: already recording to mempool_id");
2042*da0073e9SAndroid Build Coastguard Worker }
2043*da0073e9SAndroid Build Coastguard Worker captures_underway.emplace_back(mempool_id, std::move(filter));
2044*da0073e9SAndroid Build Coastguard Worker }
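
// Hedged sketch of how the filter is used (based on the capture flow in this
// file; `capture_stream` is an illustrative name): while the capture is
// underway, allocation requests whose stream satisfies the filter are routed
// to this private pool instead of the ordinary device-wide pools:
//
//   beginAllocateToPool(mempool_id, [=](cudaStream_t s) {
//     return s == capture_stream;
//   });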
2045*da0073e9SAndroid Build Coastguard Worker
2046*da0073e9SAndroid Build Coastguard Worker // Called by CUDAGraph::capture_end
2047*da0073e9SAndroid Build Coastguard Worker void endAllocateToPool(MempoolId_t mempool_id) {
2048*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
2049*da0073e9SAndroid Build Coastguard Worker for (auto it = captures_underway.begin(); it != captures_underway.end();
2050*da0073e9SAndroid Build Coastguard Worker ++it) {
2051*da0073e9SAndroid Build Coastguard Worker if (it->first == mempool_id) {
2052*da0073e9SAndroid Build Coastguard Worker captures_underway.erase(it);
2053*da0073e9SAndroid Build Coastguard Worker return;
2054*da0073e9SAndroid Build Coastguard Worker }
2055*da0073e9SAndroid Build Coastguard Worker }
2056*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
2057*da0073e9SAndroid Build Coastguard Worker false, "endAllocateToPool: not currently recording to mempool_id");
2058*da0073e9SAndroid Build Coastguard Worker }
2059*da0073e9SAndroid Build Coastguard Worker
2060*da0073e9SAndroid Build Coastguard Worker // Called by CUDAGraph::reset
2061*da0073e9SAndroid Build Coastguard Worker void releasePool(MempoolId_t mempool_id) {
2062*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
2063*da0073e9SAndroid Build Coastguard Worker // The instantiated cudaGraphExec_t has been destroyed. We can't blindly
2064*da0073e9SAndroid Build Coastguard Worker // delete and cudaFree the mempool its capture used, because
2065*da0073e9SAndroid Build Coastguard Worker // 1. other graph(s) might share the same pool
2066*da0073e9SAndroid Build Coastguard Worker // 2. the user might still hold references to output tensors allocated
2067*da0073e9SAndroid Build Coastguard Worker // during capture.
2068*da0073e9SAndroid Build Coastguard Worker // To handle 1 and 2, we track the number of graphs using this particular
2069*da0073e9SAndroid Build Coastguard Worker // mempool. When the count reaches 0, we tell free_cached_blocks it may now
2070*da0073e9SAndroid Build Coastguard Worker // cudaFree blocks from this graph's pool when it discovers they're unused
2071*da0073e9SAndroid Build Coastguard Worker // (unsplit).
2072*da0073e9SAndroid Build Coastguard Worker auto it = graph_pools.find(mempool_id);
2073*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(it != graph_pools.end());
2074*da0073e9SAndroid Build Coastguard Worker auto uc = --(it->second->use_count);
2075*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(uc >= 0);
2076*da0073e9SAndroid Build Coastguard Worker if (uc == 0) {
2077*da0073e9SAndroid Build Coastguard Worker // Allows free_cached_blocks to begin cudaFreeing this pool's memory,
2078*da0073e9SAndroid Build Coastguard Worker // and makes sure this pool wasn't somehow made freeable already.
2079*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
2080*da0073e9SAndroid Build Coastguard Worker bool inserted =
2081*da0073e9SAndroid Build Coastguard Worker graph_pools_freeable.insert({mempool_id, it->second.get()}).second;
2082*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(inserted);
2083*da0073e9SAndroid Build Coastguard Worker }
2084*da0073e9SAndroid Build Coastguard Worker }
2085*da0073e9SAndroid Build Coastguard Worker
2086*da0073e9SAndroid Build Coastguard Worker void addPeerAccess(c10::DeviceIndex dev_to_access) {
2087*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
2088*da0073e9SAndroid Build Coastguard Worker if (std::find(
2089*da0073e9SAndroid Build Coastguard Worker devices_with_peer_access_.begin(),
2090*da0073e9SAndroid Build Coastguard Worker devices_with_peer_access_.end(),
2091*da0073e9SAndroid Build Coastguard Worker dev_to_access) != devices_with_peer_access_.end()) {
2092*da0073e9SAndroid Build Coastguard Worker return;
2093*da0073e9SAndroid Build Coastguard Worker }
2094*da0073e9SAndroid Build Coastguard Worker devices_with_peer_access_.push_back(dev_to_access);
2095*da0073e9SAndroid Build Coastguard Worker for (auto& es : expandable_segments_) {
2096*da0073e9SAndroid Build Coastguard Worker es->addPeer(dev_to_access);
2097*da0073e9SAndroid Build Coastguard Worker }
2098*da0073e9SAndroid Build Coastguard Worker }
2099*da0073e9SAndroid Build Coastguard Worker std::vector<c10::DeviceIndex> peers() const {
2100*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::recursive_mutex> lock(mutex);
2101*da0073e9SAndroid Build Coastguard Worker return devices_with_peer_access_;
2102*da0073e9SAndroid Build Coastguard Worker }
2103*da0073e9SAndroid Build Coastguard Worker
2104*da0073e9SAndroid Build Coastguard Worker bool hasAllocatedExpandableSegments() const {
2105*da0073e9SAndroid Build Coastguard Worker return !expandable_segments_.empty();
2106*da0073e9SAndroid Build Coastguard Worker }
2107*da0073e9SAndroid Build Coastguard Worker
2108*da0073e9SAndroid Build Coastguard Worker private:
2109*da0073e9SAndroid Build Coastguard Worker // None of the private methods acquire the allocator mutex; callers must already hold it.
2110*da0073e9SAndroid Build Coastguard Worker
2111*da0073e9SAndroid Build Coastguard Worker std::vector<const Block*> get_all_blocks() const {
2112*da0073e9SAndroid Build Coastguard Worker std::vector<const Block*> blocks;
2113*da0073e9SAndroid Build Coastguard Worker blocks.insert(
2114*da0073e9SAndroid Build Coastguard Worker blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end());
2115*da0073e9SAndroid Build Coastguard Worker blocks.insert(
2116*da0073e9SAndroid Build Coastguard Worker blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end());
2117*da0073e9SAndroid Build Coastguard Worker for (const auto& gp : graph_pools) {
2118*da0073e9SAndroid Build Coastguard Worker blocks.insert(
2119*da0073e9SAndroid Build Coastguard Worker blocks.end(),
2120*da0073e9SAndroid Build Coastguard Worker gp.second->small_blocks.blocks.begin(),
2121*da0073e9SAndroid Build Coastguard Worker gp.second->small_blocks.blocks.end());
2122*da0073e9SAndroid Build Coastguard Worker blocks.insert(
2123*da0073e9SAndroid Build Coastguard Worker blocks.end(),
2124*da0073e9SAndroid Build Coastguard Worker gp.second->large_blocks.blocks.begin(),
2125*da0073e9SAndroid Build Coastguard Worker gp.second->large_blocks.blocks.end());
2126*da0073e9SAndroid Build Coastguard Worker }
2127*da0073e9SAndroid Build Coastguard Worker blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end());
2128*da0073e9SAndroid Build Coastguard Worker return blocks;
2129*da0073e9SAndroid Build Coastguard Worker }
2130*da0073e9SAndroid Build Coastguard Worker
2131*da0073e9SAndroid Build Coastguard Worker std::vector<Block*> get_private_pool_head_blocks(PrivatePool* pool) const {
2132*da0073e9SAndroid Build Coastguard Worker std::vector<Block*> blocks;
2133*da0073e9SAndroid Build Coastguard Worker for (Block* b : active_blocks) {
2134*da0073e9SAndroid Build Coastguard Worker if ((b->pool == &pool->small_blocks || b->pool == &pool->large_blocks) &&
2135*da0073e9SAndroid Build Coastguard Worker b->prev == nullptr) {
2136*da0073e9SAndroid Build Coastguard Worker blocks.push_back(b);
2137*da0073e9SAndroid Build Coastguard Worker }
2138*da0073e9SAndroid Build Coastguard Worker }
2139*da0073e9SAndroid Build Coastguard Worker
2140*da0073e9SAndroid Build Coastguard Worker for (Block* b : pool->small_blocks.blocks) {
2141*da0073e9SAndroid Build Coastguard Worker if (b->prev == nullptr) {
2142*da0073e9SAndroid Build Coastguard Worker blocks.push_back(b);
2143*da0073e9SAndroid Build Coastguard Worker }
2144*da0073e9SAndroid Build Coastguard Worker }
2145*da0073e9SAndroid Build Coastguard Worker for (Block* b : pool->large_blocks.blocks) {
2146*da0073e9SAndroid Build Coastguard Worker if (b->prev == nullptr) {
2147*da0073e9SAndroid Build Coastguard Worker blocks.push_back(b);
2148*da0073e9SAndroid Build Coastguard Worker }
2149*da0073e9SAndroid Build Coastguard Worker }
2150*da0073e9SAndroid Build Coastguard Worker
2151*da0073e9SAndroid Build Coastguard Worker return blocks;
2152*da0073e9SAndroid Build Coastguard Worker }
2153*da0073e9SAndroid Build Coastguard Worker
2154*da0073e9SAndroid Build Coastguard Worker // Returns the block at the smallest possible address in any segment
2155*da0073e9SAndroid Build Coastguard Worker // that still has enough free address space to fit size; that space
2156*da0073e9SAndroid Build Coastguard Worker // may be composed of free and unmapped blocks.
2157*da0073e9SAndroid Build Coastguard Worker Block* find_expandable_block(
2158*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
2159*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream,
2160*da0073e9SAndroid Build Coastguard Worker BlockPool* pool,
2161*da0073e9SAndroid Build Coastguard Worker size_t size) {
2162*da0073e9SAndroid Build Coastguard Worker Block key(device, stream, 0);
2163*da0073e9SAndroid Build Coastguard Worker
2164*da0073e9SAndroid Build Coastguard Worker auto allocatable = [](Block* b) {
2165*da0073e9SAndroid Build Coastguard Worker return b && !b->allocated && b->event_count == 0 &&
2166*da0073e9SAndroid Build Coastguard Worker b->stream_uses.empty();
2167*da0073e9SAndroid Build Coastguard Worker };
2168*da0073e9SAndroid Build Coastguard Worker auto has_available_address_space = [&](Block* b) {
2169*da0073e9SAndroid Build Coastguard Worker size_t bytes = 0;
2170*da0073e9SAndroid Build Coastguard Worker while (bytes < size && allocatable(b)) {
2171*da0073e9SAndroid Build Coastguard Worker bytes += b->size;
2172*da0073e9SAndroid Build Coastguard Worker b = b->next;
2173*da0073e9SAndroid Build Coastguard Worker }
2174*da0073e9SAndroid Build Coastguard Worker return bytes >= size;
2175*da0073e9SAndroid Build Coastguard Worker };
2176*da0073e9SAndroid Build Coastguard Worker for (auto it = pool->unmapped.lower_bound(&key);
2177*da0073e9SAndroid Build Coastguard Worker it != pool->unmapped.end() && (*it)->stream == stream;
2178*da0073e9SAndroid Build Coastguard Worker ++it) {
2179*da0073e9SAndroid Build Coastguard Worker Block* c = *it;
2180*da0073e9SAndroid Build Coastguard Worker // we found the lowest address of an unmapped segment
2181*da0073e9SAndroid Build Coastguard Worker // but there might be a free segment we can also use
2182*da0073e9SAndroid Build Coastguard Worker // right before it
2183*da0073e9SAndroid Build Coastguard Worker if (allocatable(c->prev)) {
2184*da0073e9SAndroid Build Coastguard Worker c = c->prev;
2185*da0073e9SAndroid Build Coastguard Worker }
2186*da0073e9SAndroid Build Coastguard Worker if (has_available_address_space(c)) {
2187*da0073e9SAndroid Build Coastguard Worker return c;
2188*da0073e9SAndroid Build Coastguard Worker }
2189*da0073e9SAndroid Build Coastguard Worker }
2190*da0073e9SAndroid Build Coastguard Worker auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer;
2191*da0073e9SAndroid Build Coastguard Worker cudaDeviceProp prop{};
2192*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
2193*da0073e9SAndroid Build Coastguard Worker // We reserve enough virtual address space for 1 1/8 times the total
2194*da0073e9SAndroid Build Coastguard Worker // memory on the GPU. This allows for cases where we have to unmap pages
2195*da0073e9SAndroid Build Coastguard Worker // earlier in the segment to put them at the end.
2196*da0073e9SAndroid Build Coastguard Worker size_t address_space_size = prop.totalGlobalMem + prop.totalGlobalMem / 8;
2197*da0073e9SAndroid Build Coastguard Worker
2198*da0073e9SAndroid Build Coastguard Worker expandable_segments_.emplace_back(new ExpandableSegment(
2199*da0073e9SAndroid Build Coastguard Worker device,
2200*da0073e9SAndroid Build Coastguard Worker stream,
2201*da0073e9SAndroid Build Coastguard Worker address_space_size,
2202*da0073e9SAndroid Build Coastguard Worker segment_size,
2203*da0073e9SAndroid Build Coastguard Worker devices_with_peer_access_));
2204*da0073e9SAndroid Build Coastguard Worker
2205*da0073e9SAndroid Build Coastguard Worker ExpandableSegment* es = expandable_segments_.back();
2206*da0073e9SAndroid Build Coastguard Worker Block* candidate = new Block(device, stream, es->size(), pool, es->ptr());
2207*da0073e9SAndroid Build Coastguard Worker candidate->mapped = false;
2208*da0073e9SAndroid Build Coastguard Worker candidate->expandable_segment_ = es;
2209*da0073e9SAndroid Build Coastguard Worker pool->unmapped.insert(candidate);
2210*da0073e9SAndroid Build Coastguard Worker return candidate;
2211*da0073e9SAndroid Build Coastguard Worker }
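// [Editor's illustrative sketch; not part of the allocator.] Worked example
// of the reservation size computed above: on a device reporting 16 GiB of
// totalGlobalMem, address_space_size = 16 GiB + 16 GiB / 8 = 18 GiB of
// virtual address space is reserved, while physical memory is committed only
// as kSmallBuffer/kLargeBuffer-sized chunks get mapped into it.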
2212*da0073e9SAndroid Build Coastguard Worker
2213*da0073e9SAndroid Build Coastguard Worker bool map_block(
2214*da0073e9SAndroid Build Coastguard Worker Block* to_map,
2215*da0073e9SAndroid Build Coastguard Worker size_t size,
2216*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& ctx) {
2217*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size);
2218*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
2219*da0073e9SAndroid Build Coastguard Worker !to_map->context_when_allocated); // unmapped blocks should not keep
2220*da0073e9SAndroid Build Coastguard Worker // history
2221*da0073e9SAndroid Build Coastguard Worker auto mapped_range =
2222*da0073e9SAndroid Build Coastguard Worker to_map->expandable_segment_->map(SegmentRange{to_map->ptr, size});
2223*da0073e9SAndroid Build Coastguard Worker // failed to map the memory
2224*da0073e9SAndroid Build Coastguard Worker if (mapped_range.size == 0) {
2225*da0073e9SAndroid Build Coastguard Worker return false;
2226*da0073e9SAndroid Build Coastguard Worker }
2227*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
2228*da0073e9SAndroid Build Coastguard Worker mapped_range.ptr == to_map->ptr && mapped_range.size >= size);
2229*da0073e9SAndroid Build Coastguard Worker
2230*da0073e9SAndroid Build Coastguard Worker BlockPool& pool = *to_map->pool;
2231*da0073e9SAndroid Build Coastguard Worker pool.unmapped.erase(to_map);
2232*da0073e9SAndroid Build Coastguard Worker to_map->mapped = true;
2233*da0073e9SAndroid Build Coastguard Worker
2234*da0073e9SAndroid Build Coastguard Worker if (mapped_range.size < to_map->size) {
2235*da0073e9SAndroid Build Coastguard Worker // to_map -> remaining -> to_map->next(?)
2236*da0073e9SAndroid Build Coastguard Worker Block* remaining = new Block(
2237*da0073e9SAndroid Build Coastguard Worker to_map->device,
2238*da0073e9SAndroid Build Coastguard Worker to_map->stream,
2239*da0073e9SAndroid Build Coastguard Worker to_map->size - mapped_range.size,
2240*da0073e9SAndroid Build Coastguard Worker &pool,
2241*da0073e9SAndroid Build Coastguard Worker static_cast<char*>(to_map->ptr) + mapped_range.size);
2242*da0073e9SAndroid Build Coastguard Worker remaining->mapped = false;
2243*da0073e9SAndroid Build Coastguard Worker remaining->expandable_segment_ = to_map->expandable_segment_;
2244*da0073e9SAndroid Build Coastguard Worker remaining->splice(to_map, to_map->next);
2245*da0073e9SAndroid Build Coastguard Worker pool.unmapped.insert(remaining);
2246*da0073e9SAndroid Build Coastguard Worker to_map->size = mapped_range.size;
2247*da0073e9SAndroid Build Coastguard Worker }
2248*da0073e9SAndroid Build Coastguard Worker
2249*da0073e9SAndroid Build Coastguard Worker try_merge_blocks(to_map, to_map->prev, pool);
2250*da0073e9SAndroid Build Coastguard Worker try_merge_blocks(to_map, to_map->next, pool);
2251*da0073e9SAndroid Build Coastguard Worker
2252*da0073e9SAndroid Build Coastguard Worker pool.insert_into_blocks(to_map);
2253*da0073e9SAndroid Build Coastguard Worker
2254*da0073e9SAndroid Build Coastguard Worker // update statistics
2255*da0073e9SAndroid Build Coastguard Worker total_allocated_memory += mapped_range.size;
2256*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
2257*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
2258*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[stat_type].increase(mapped_range.size);
2259*da0073e9SAndroid Build Coastguard Worker });
2260*da0073e9SAndroid Build Coastguard Worker auto reserved_bytes_gauge =
2261*da0073e9SAndroid Build Coastguard Worker STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
2262*da0073e9SAndroid Build Coastguard Worker reserved_bytes_gauge.record(
2263*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
2264*da0073e9SAndroid Build Coastguard Worker .current);
2265*da0073e9SAndroid Build Coastguard Worker
2266*da0073e9SAndroid Build Coastguard Worker stats.num_device_alloc++;
2267*da0073e9SAndroid Build Coastguard Worker record_trace(
2268*da0073e9SAndroid Build Coastguard Worker TraceEntry::SEGMENT_MAP,
2269*da0073e9SAndroid Build Coastguard Worker int64_t(mapped_range.ptr),
2270*da0073e9SAndroid Build Coastguard Worker mapped_range.size,
2271*da0073e9SAndroid Build Coastguard Worker to_map->stream,
2272*da0073e9SAndroid Build Coastguard Worker to_map->device,
2273*da0073e9SAndroid Build Coastguard Worker ctx);
2274*da0073e9SAndroid Build Coastguard Worker if (!to_map->prev && !to_map->context_when_segment_allocated) {
2275*da0073e9SAndroid Build Coastguard Worker to_map->context_when_segment_allocated = ctx;
2276*da0073e9SAndroid Build Coastguard Worker }
2277*da0073e9SAndroid Build Coastguard Worker
2278*da0073e9SAndroid Build Coastguard Worker return true;
2279*da0073e9SAndroid Build Coastguard Worker }
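// [Editor's illustrative sketch; not part of the allocator.] Suppose to_map
// spans 64 MiB of unmapped address space and only 20 MiB is requested. If the
// segment maps exactly 20 MiB (actual sizes are rounded to the segment
// granularity), the code above splits the block:
//
//   before:  [to_map: 64 MiB, unmapped]
//   after:   [to_map: 20 MiB, mapped] -> [remaining: 44 MiB, unmapped]
//
// and then tries to merge to_map with mapped free neighbours on either side.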
2280*da0073e9SAndroid Build Coastguard Worker
2281*da0073e9SAndroid Build Coastguard Worker Block* try_allocate_expandable_block(
2282*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
2283*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream,
2284*da0073e9SAndroid Build Coastguard Worker BlockPool* pool,
2285*da0073e9SAndroid Build Coastguard Worker size_t size,
2286*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& ctx) {
2287*da0073e9SAndroid Build Coastguard Worker Block* candidate = find_expandable_block(device, stream, pool, size);
2288*da0073e9SAndroid Build Coastguard Worker // Candidate is now a list of free/unmapped blocks with at least size room:
2289*da0073e9SAndroid Build Coastguard Worker // unmapped -> null
2290*da0073e9SAndroid Build Coastguard Worker // unmapped -> free -> *
2291*da0073e9SAndroid Build Coastguard Worker // free -> unmapped -> *
2292*da0073e9SAndroid Build Coastguard Worker
2293*da0073e9SAndroid Build Coastguard Worker if (!candidate->mapped &&
2294*da0073e9SAndroid Build Coastguard Worker !map_block(candidate, std::min(candidate->size, size), ctx)) {
2295*da0073e9SAndroid Build Coastguard Worker return nullptr;
2296*da0073e9SAndroid Build Coastguard Worker }
2297*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(candidate->mapped);
2298*da0073e9SAndroid Build Coastguard Worker
2299*da0073e9SAndroid Build Coastguard Worker while (candidate->size < size) {
2300*da0073e9SAndroid Build Coastguard Worker // invariant: free -> unmapped -> *
2301*da0073e9SAndroid Build Coastguard Worker // map_block will map some of unmapped and merge with free
2302*da0073e9SAndroid Build Coastguard Worker auto remaining = size - candidate->size;
2303*da0073e9SAndroid Build Coastguard Worker auto new_candidate = candidate->next;
2304*da0073e9SAndroid Build Coastguard Worker if (!map_block(
2305*da0073e9SAndroid Build Coastguard Worker new_candidate, std::min(remaining, candidate->next->size), ctx)) {
2306*da0073e9SAndroid Build Coastguard Worker return nullptr;
2307*da0073e9SAndroid Build Coastguard Worker }
2308*da0073e9SAndroid Build Coastguard Worker candidate = new_candidate;
2309*da0073e9SAndroid Build Coastguard Worker }
2310*da0073e9SAndroid Build Coastguard Worker pool->blocks.erase(candidate);
2311*da0073e9SAndroid Build Coastguard Worker return candidate;
2312*da0073e9SAndroid Build Coastguard Worker }
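// [Editor's illustrative sketch; not part of the allocator.] A 30 MiB request
// against a candidate chain "free 12 MiB (mapped) -> unmapped 100 MiB" would
// proceed roughly as:
//
//   candidate = the 12 MiB free block (already mapped, so no initial map)
//   loop: remaining = 30 - 12 = 18 MiB
//         map_block(candidate->next, 18 MiB) maps 18 MiB of the unmapped
//         block (rounded to granularity) and merges the 12 MiB free block
//         into it, so the combined >= 30 MiB block becomes the candidate
//   the candidate is erased from pool->blocks and returned.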
2313*da0073e9SAndroid Build Coastguard Worker
2314*da0073e9SAndroid Build Coastguard Worker /** moves a block into a pool of cached free blocks */
2315*da0073e9SAndroid Build Coastguard Worker void free_block(
2316*da0073e9SAndroid Build Coastguard Worker Block* block,
2317*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& context) {
2318*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
2319*da0073e9SAndroid Build Coastguard Worker !block->allocated && block->event_count == 0 &&
2320*da0073e9SAndroid Build Coastguard Worker block->stream_uses.empty());
2321*da0073e9SAndroid Build Coastguard Worker
2322*da0073e9SAndroid Build Coastguard Worker record_trace(
2323*da0073e9SAndroid Build Coastguard Worker TraceEntry::FREE_COMPLETED,
2324*da0073e9SAndroid Build Coastguard Worker int64_t(block->ptr),
2325*da0073e9SAndroid Build Coastguard Worker block->requested_size,
2326*da0073e9SAndroid Build Coastguard Worker block->stream,
2327*da0073e9SAndroid Build Coastguard Worker block->device,
2328*da0073e9SAndroid Build Coastguard Worker context ? context : block->context_when_allocated);
2329*da0073e9SAndroid Build Coastguard Worker
2330*da0073e9SAndroid Build Coastguard Worker block->context_when_allocated = nullptr;
2331*da0073e9SAndroid Build Coastguard Worker size_t original_block_size = block->size;
2332*da0073e9SAndroid Build Coastguard Worker size_t requested_size = block->requested_size;
2333*da0073e9SAndroid Build Coastguard Worker
2334*da0073e9SAndroid Build Coastguard Worker auto& pool = *block->pool;
2335*da0073e9SAndroid Build Coastguard Worker int64_t net_change_inactive_split_blocks = 0;
2336*da0073e9SAndroid Build Coastguard Worker int64_t net_change_inactive_split_size = 0;
2337*da0073e9SAndroid Build Coastguard Worker
2338*da0073e9SAndroid Build Coastguard Worker const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
2339*da0073e9SAndroid Build Coastguard Worker for (Block* merge_candidate : merge_candidates) {
2340*da0073e9SAndroid Build Coastguard Worker const auto subsumed_size = try_merge_blocks(block, merge_candidate, pool);
2341*da0073e9SAndroid Build Coastguard Worker if (subsumed_size > 0) {
2342*da0073e9SAndroid Build Coastguard Worker net_change_inactive_split_blocks -= 1;
2343*da0073e9SAndroid Build Coastguard Worker net_change_inactive_split_size -= static_cast<int64_t>(subsumed_size);
2344*da0073e9SAndroid Build Coastguard Worker }
2345*da0073e9SAndroid Build Coastguard Worker }
2346*da0073e9SAndroid Build Coastguard Worker
2347*da0073e9SAndroid Build Coastguard Worker active_blocks.erase(block);
2348*da0073e9SAndroid Build Coastguard Worker // Makes sure the Block* isn't already present in the pool we're freeing it
2349*da0073e9SAndroid Build Coastguard Worker // back into.
2350*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
2351*da0073e9SAndroid Build Coastguard Worker bool inserted = pool.insert_into_blocks(block).second;
2352*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(inserted);
2353*da0073e9SAndroid Build Coastguard Worker
2354*da0073e9SAndroid Build Coastguard Worker if (block->is_split()) {
2355*da0073e9SAndroid Build Coastguard Worker net_change_inactive_split_blocks += 1;
2356*da0073e9SAndroid Build Coastguard Worker net_change_inactive_split_size += static_cast<int64_t>(block->size);
2357*da0073e9SAndroid Build Coastguard Worker }
2358*da0073e9SAndroid Build Coastguard Worker
2359*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = get_stat_types_for_pool(pool);
2360*da0073e9SAndroid Build Coastguard Worker
2361*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
2362*da0073e9SAndroid Build Coastguard Worker // inactive_split tries to capture the idea that blocks
2363*da0073e9SAndroid Build Coastguard Worker // cannot be freed when requested, but fully free pages
2364*da0073e9SAndroid Build Coastguard Worker // of expandable blocks can always be freed.
2365*da0073e9SAndroid Build Coastguard Worker // The logic to track this as a statistic is pretty involved,
2366*da0073e9SAndroid Build Coastguard Worker // so we simply exclude expandable segments from
2367*da0073e9SAndroid Build Coastguard Worker // inactive_split.
2368*da0073e9SAndroid Build Coastguard Worker if (!block->expandable_segment_) {
2369*da0073e9SAndroid Build Coastguard Worker if (net_change_inactive_split_blocks > 0) {
2370*da0073e9SAndroid Build Coastguard Worker stats.inactive_split[stat_type].increase(
2371*da0073e9SAndroid Build Coastguard Worker static_cast<size_t>(net_change_inactive_split_blocks));
2372*da0073e9SAndroid Build Coastguard Worker } else if (net_change_inactive_split_blocks < 0) {
2373*da0073e9SAndroid Build Coastguard Worker stats.inactive_split[stat_type].decrease(
2374*da0073e9SAndroid Build Coastguard Worker static_cast<size_t>(-net_change_inactive_split_blocks));
2375*da0073e9SAndroid Build Coastguard Worker }
2376*da0073e9SAndroid Build Coastguard Worker if (net_change_inactive_split_size > 0) {
2377*da0073e9SAndroid Build Coastguard Worker stats.inactive_split_bytes[stat_type].increase(
2378*da0073e9SAndroid Build Coastguard Worker static_cast<size_t>(net_change_inactive_split_size));
2379*da0073e9SAndroid Build Coastguard Worker } else if (net_change_inactive_split_size < 0) {
2380*da0073e9SAndroid Build Coastguard Worker stats.inactive_split_bytes[stat_type].decrease(
2381*da0073e9SAndroid Build Coastguard Worker static_cast<size_t>(-net_change_inactive_split_size));
2382*da0073e9SAndroid Build Coastguard Worker }
2383*da0073e9SAndroid Build Coastguard Worker }
2384*da0073e9SAndroid Build Coastguard Worker stats.active[stat_type].decrease(1);
2385*da0073e9SAndroid Build Coastguard Worker stats.active_bytes[stat_type].decrease(original_block_size);
2386*da0073e9SAndroid Build Coastguard Worker stats.requested_bytes[stat_type].decrease(requested_size);
2387*da0073e9SAndroid Build Coastguard Worker });
2388*da0073e9SAndroid Build Coastguard Worker }
2389*da0073e9SAndroid Build Coastguard Worker
2390*da0073e9SAndroid Build Coastguard Worker /** combine previously split blocks. returns the size of the subsumed block,
2391*da0073e9SAndroid Build Coastguard Worker * or 0 on failure. */
2392*da0073e9SAndroid Build Coastguard Worker size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
2393*da0073e9SAndroid Build Coastguard Worker if (!src || src->allocated || src->event_count > 0 ||
2394*da0073e9SAndroid Build Coastguard Worker !src->stream_uses.empty() || dst->mapped != src->mapped) {
2395*da0073e9SAndroid Build Coastguard Worker return 0;
2396*da0073e9SAndroid Build Coastguard Worker }
2397*da0073e9SAndroid Build Coastguard Worker
2398*da0073e9SAndroid Build Coastguard Worker AT_ASSERT(dst->is_split() && src->is_split());
2399*da0073e9SAndroid Build Coastguard Worker
2400*da0073e9SAndroid Build Coastguard Worker if (dst->prev == src) { // [src dst]
2401*da0073e9SAndroid Build Coastguard Worker dst->ptr = src->ptr;
2402*da0073e9SAndroid Build Coastguard Worker dst->prev = src->prev;
2403*da0073e9SAndroid Build Coastguard Worker if (dst->prev) {
2404*da0073e9SAndroid Build Coastguard Worker dst->prev->next = dst;
2405*da0073e9SAndroid Build Coastguard Worker }
2406*da0073e9SAndroid Build Coastguard Worker dst->context_when_segment_allocated =
2407*da0073e9SAndroid Build Coastguard Worker std::move(src->context_when_segment_allocated);
2408*da0073e9SAndroid Build Coastguard Worker } else { // [dest src]
2409*da0073e9SAndroid Build Coastguard Worker dst->next = src->next;
2410*da0073e9SAndroid Build Coastguard Worker if (dst->next) {
2411*da0073e9SAndroid Build Coastguard Worker dst->next->prev = dst;
2412*da0073e9SAndroid Build Coastguard Worker }
2413*da0073e9SAndroid Build Coastguard Worker }
2414*da0073e9SAndroid Build Coastguard Worker const size_t subsumed_size = src->size;
2415*da0073e9SAndroid Build Coastguard Worker dst->size += subsumed_size;
2416*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
2417*da0073e9SAndroid Build Coastguard Worker auto erased =
2418*da0073e9SAndroid Build Coastguard Worker src->mapped ? pool.blocks.erase(src) : pool.unmapped.erase(src);
2419*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
2420*da0073e9SAndroid Build Coastguard Worker delete src;
2421*da0073e9SAndroid Build Coastguard Worker
2422*da0073e9SAndroid Build Coastguard Worker return subsumed_size;
2423*da0073e9SAndroid Build Coastguard Worker }
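// [Editor's illustrative sketch; not part of the allocator.] The two merge
// directions handled above:
//
//   [src][dst]  ->  dst->ptr becomes src->ptr, dst absorbs src's size, and
//                   dst is re-linked to src->prev (src was the left neighbour)
//   [dst][src]  ->  dst keeps its ptr, absorbs src's size, and is re-linked
//                   to src->next (src was the right neighbour)
//
// In both cases src is erased from pool.blocks (or pool.unmapped) and deleted.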
2424*da0073e9SAndroid Build Coastguard Worker
2425*da0073e9SAndroid Build Coastguard Worker BlockPool& get_pool(size_t size, cudaStream_t stream) {
2426*da0073e9SAndroid Build Coastguard Worker // captures_underway is a conservative guess that the current stream may be
2427*da0073e9SAndroid Build Coastguard Worker // capturing. It's only non-empty if some thread has begun and not yet ended
2428*da0073e9SAndroid Build Coastguard Worker // a capture, so it's usually empty, and we can short-circuit the
2429*da0073e9SAndroid Build Coastguard Worker // cudaStreamCaptureStatus query (which does a TLS lookup).
2430*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(!captures_underway.empty())) {
2431*da0073e9SAndroid Build Coastguard Worker for (auto& entry : captures_underway) {
2432*da0073e9SAndroid Build Coastguard Worker if (entry.second(stream)) {
2433*da0073e9SAndroid Build Coastguard Worker auto it1 = graph_pools.find(entry.first);
2434*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
2435*da0073e9SAndroid Build Coastguard Worker if (size <= kSmallSize) {
2436*da0073e9SAndroid Build Coastguard Worker return it1->second->small_blocks;
2437*da0073e9SAndroid Build Coastguard Worker } else {
2438*da0073e9SAndroid Build Coastguard Worker return it1->second->large_blocks;
2439*da0073e9SAndroid Build Coastguard Worker }
2440*da0073e9SAndroid Build Coastguard Worker }
2441*da0073e9SAndroid Build Coastguard Worker }
2442*da0073e9SAndroid Build Coastguard Worker }
2443*da0073e9SAndroid Build Coastguard Worker if (size <= kSmallSize) {
2444*da0073e9SAndroid Build Coastguard Worker return small_blocks;
2445*da0073e9SAndroid Build Coastguard Worker } else {
2446*da0073e9SAndroid Build Coastguard Worker return large_blocks;
2447*da0073e9SAndroid Build Coastguard Worker }
2448*da0073e9SAndroid Build Coastguard Worker }
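// [Editor's illustrative sketch; not part of the allocator.] Routing example:
// while a capture registered for stream S is underway, a 512 KiB allocation
// on S returns graph_pools[id]->small_blocks (512 KiB <= kSmallSize with the
// usual constants), while the same request on a non-capturing stream falls
// through to the global small_blocks pool.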
2449*da0073e9SAndroid Build Coastguard Worker
2450*da0073e9SAndroid Build Coastguard Worker StatTypes get_stat_types_for_pool(const BlockPool& pool) {
2451*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = {false};
2452*da0073e9SAndroid Build Coastguard Worker stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
2453*da0073e9SAndroid Build Coastguard Worker stat_types[static_cast<size_t>(
2454*da0073e9SAndroid Build Coastguard Worker pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true;
2455*da0073e9SAndroid Build Coastguard Worker return stat_types;
2456*da0073e9SAndroid Build Coastguard Worker }
2457*da0073e9SAndroid Build Coastguard Worker
2458*da0073e9SAndroid Build Coastguard Worker bool should_split(const Block* block, size_t size) {
2459*da0073e9SAndroid Build Coastguard Worker size_t remaining = block->size - size;
2460*da0073e9SAndroid Build Coastguard Worker if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) {
2461*da0073e9SAndroid Build Coastguard Worker return remaining >= kMinBlockSize;
2462*da0073e9SAndroid Build Coastguard Worker } else {
2463*da0073e9SAndroid Build Coastguard Worker return (size < CUDAAllocatorConfig::max_split_size()) &&
2464*da0073e9SAndroid Build Coastguard Worker (remaining > kSmallSize);
2465*da0073e9SAndroid Build Coastguard Worker }
2466*da0073e9SAndroid Build Coastguard Worker }
2467*da0073e9SAndroid Build Coastguard Worker
2468*da0073e9SAndroid Build Coastguard Worker static size_t get_allocation_size(size_t size) {
2469*da0073e9SAndroid Build Coastguard Worker if (size <= kSmallSize) {
2470*da0073e9SAndroid Build Coastguard Worker return kSmallBuffer;
2471*da0073e9SAndroid Build Coastguard Worker } else if (size < kMinLargeAlloc) {
2472*da0073e9SAndroid Build Coastguard Worker return kLargeBuffer;
2473*da0073e9SAndroid Build Coastguard Worker } else {
2474*da0073e9SAndroid Build Coastguard Worker return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
2475*da0073e9SAndroid Build Coastguard Worker }
2476*da0073e9SAndroid Build Coastguard Worker }
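// [Editor's illustrative sketch; not part of the allocator.] Assuming the
// usual constants (kSmallSize = 1 MiB, kSmallBuffer = 2 MiB,
// kMinLargeAlloc = 10 MiB, kRoundLarge = 2 MiB, kLargeBuffer = 20 MiB),
// get_allocation_size behaves like:
//
//   get_allocation_size(512 KiB) == 2 MiB   (small pool buffer)
//   get_allocation_size(5 MiB)   == 20 MiB  (large pool buffer)
//   get_allocation_size(11 MiB)  == 12 MiB  (rounded up to a 2 MiB multiple)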
2477*da0073e9SAndroid Build Coastguard Worker
2478*da0073e9SAndroid Build Coastguard Worker bool get_free_block(AllocParams& p) {
2479*da0073e9SAndroid Build Coastguard Worker BlockPool& pool = *p.pool;
2480*da0073e9SAndroid Build Coastguard Worker
2481*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(
2482*da0073e9SAndroid Build Coastguard Worker set_fraction &&
2483*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
2484*da0073e9SAndroid Build Coastguard Worker // Track block reuse interval only when garbage collection is enabled.
2485*da0073e9SAndroid Build Coastguard Worker ++pool.get_free_blocks_call_count;
2486*da0073e9SAndroid Build Coastguard Worker }
2487*da0073e9SAndroid Build Coastguard Worker auto it = pool.blocks.lower_bound(&p.search_key);
2488*da0073e9SAndroid Build Coastguard Worker if (it == pool.blocks.end() || (*it)->stream != p.stream())
2489*da0073e9SAndroid Build Coastguard Worker return false;
2490*da0073e9SAndroid Build Coastguard Worker
2491*da0073e9SAndroid Build Coastguard Worker if ((*it)->expandable_segment_) {
2492*da0073e9SAndroid Build Coastguard Worker if (CUDAAllocatorConfig::expandable_segments()) {
2493*da0073e9SAndroid Build Coastguard Worker // If we are allocating in the expandable part of a segment, then for the
2494*da0073e9SAndroid Build Coastguard Worker // purposes of "best fit" we consider its size to be the size it
2495*da0073e9SAndroid Build Coastguard Worker // can expand to, not the size it currently is. This means that we
2496*da0073e9SAndroid Build Coastguard Worker // sometimes have to search for blocks with a bigger 'size' before
2497*da0073e9SAndroid Build Coastguard Worker // choosing this segment.
2498*da0073e9SAndroid Build Coastguard Worker auto expandable_size = [](Block* b) {
2499*da0073e9SAndroid Build Coastguard Worker return b->size + (b->next && !b->next->mapped ? b->next->size : 0);
2500*da0073e9SAndroid Build Coastguard Worker };
2501*da0073e9SAndroid Build Coastguard Worker auto next = it;
2502*da0073e9SAndroid Build Coastguard Worker next++;
2503*da0073e9SAndroid Build Coastguard Worker while ((*it)->expandable_segment_ && next != pool.blocks.end() &&
2504*da0073e9SAndroid Build Coastguard Worker (*next)->stream == p.stream() &&
2505*da0073e9SAndroid Build Coastguard Worker expandable_size(*next) < expandable_size(*it)) {
2506*da0073e9SAndroid Build Coastguard Worker it = next++;
2507*da0073e9SAndroid Build Coastguard Worker }
2508*da0073e9SAndroid Build Coastguard Worker } else {
2509*da0073e9SAndroid Build Coastguard Worker // Rarely, expandable segments have been turned off after we have
2510*da0073e9SAndroid Build Coastguard Worker // already allocated some blocks as expandable. For instance,
2511*da0073e9SAndroid Build Coastguard Worker // since we cannot share expandable memory via IPC, someone might
2512*da0073e9SAndroid Build Coastguard Worker // temporarily disable it. In this case we need to honor the request
2513*da0073e9SAndroid Build Coastguard Worker // by only returning non-expandable blocks.
2514*da0073e9SAndroid Build Coastguard Worker do {
2515*da0073e9SAndroid Build Coastguard Worker it++;
2516*da0073e9SAndroid Build Coastguard Worker } while (it != pool.blocks.end() && (*it)->expandable_segment_ &&
2517*da0073e9SAndroid Build Coastguard Worker (*it)->stream == p.stream());
2518*da0073e9SAndroid Build Coastguard Worker if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
2519*da0073e9SAndroid Build Coastguard Worker return false;
2520*da0073e9SAndroid Build Coastguard Worker }
2521*da0073e9SAndroid Build Coastguard Worker }
2522*da0073e9SAndroid Build Coastguard Worker }
2523*da0073e9SAndroid Build Coastguard Worker
2524*da0073e9SAndroid Build Coastguard Worker // Do not return an oversized block to a request that is below max_split_size
2525*da0073e9SAndroid Build Coastguard Worker if ((p.size() < CUDAAllocatorConfig::max_split_size()) &&
2526*da0073e9SAndroid Build Coastguard Worker ((*it)->size >= CUDAAllocatorConfig::max_split_size()))
2527*da0073e9SAndroid Build Coastguard Worker return false;
2528*da0073e9SAndroid Build Coastguard Worker // Allow oversized block size to be rounded up but within a limit
2529*da0073e9SAndroid Build Coastguard Worker if ((p.size() >= CUDAAllocatorConfig::max_split_size()) &&
2530*da0073e9SAndroid Build Coastguard Worker ((*it)->size >= p.size() + kLargeBuffer))
2531*da0073e9SAndroid Build Coastguard Worker return false;
2532*da0073e9SAndroid Build Coastguard Worker p.block = *it;
2533*da0073e9SAndroid Build Coastguard Worker pool.blocks.erase(it);
2534*da0073e9SAndroid Build Coastguard Worker return true;
2535*da0073e9SAndroid Build Coastguard Worker }
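// [Editor's illustrative sketch; not part of the allocator.] With
// max_split_size configured to, say, 200 MiB and kLargeBuffer = 20 MiB, the
// two guards above mean:
//
//   request  64 MiB, candidate block 512 MiB -> rejected (oversized block,
//                                               non-oversized request)
//   request 512 MiB, candidate block 520 MiB -> accepted (within
//                                               request + kLargeBuffer)
//   request 512 MiB, candidate block 700 MiB -> rejected (too much slack)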
2536*da0073e9SAndroid Build Coastguard Worker
2537*da0073e9SAndroid Build Coastguard Worker bool trigger_free_memory_callbacks(AllocParams& p) {
2538*da0073e9SAndroid Build Coastguard Worker bool freed_memory = false;
2539*da0073e9SAndroid Build Coastguard Worker for (const auto& name : FreeCudaMemoryCallbacksRegistry()->Keys()) {
2540*da0073e9SAndroid Build Coastguard Worker freed_memory |=
2541*da0073e9SAndroid Build Coastguard Worker FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute();
2542*da0073e9SAndroid Build Coastguard Worker }
2543*da0073e9SAndroid Build Coastguard Worker return freed_memory;
2544*da0073e9SAndroid Build Coastguard Worker }
2545*da0073e9SAndroid Build Coastguard Worker
2546*da0073e9SAndroid Build Coastguard Worker void garbage_collect_cached_blocks(
2547*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& context) {
2548*da0073e9SAndroid Build Coastguard Worker // Free unused cached blocks to reclaim GPU memory.
2549*da0073e9SAndroid Build Coastguard Worker // Unlike release_cached_blocks(), this does not enforce synchronization and
2550*da0073e9SAndroid Build Coastguard Worker // therefore incurs less overhead.
2551*da0073e9SAndroid Build Coastguard Worker
2552*da0073e9SAndroid Build Coastguard Worker size_t gc_threshold = static_cast<size_t>(
2553*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig::garbage_collection_threshold() *
2554*da0073e9SAndroid Build Coastguard Worker static_cast<double>(allowed_memory_maximum));
2555*da0073e9SAndroid Build Coastguard Worker // No need to trigger GC yet
2556*da0073e9SAndroid Build Coastguard Worker if (total_allocated_memory <= gc_threshold) {
2557*da0073e9SAndroid Build Coastguard Worker return;
2558*da0073e9SAndroid Build Coastguard Worker }
2559*da0073e9SAndroid Build Coastguard Worker const auto target_size = total_allocated_memory - gc_threshold;
2560*da0073e9SAndroid Build Coastguard Worker size_t gc_reclaimed = 0;
2561*da0073e9SAndroid Build Coastguard Worker
2562*da0073e9SAndroid Build Coastguard Worker // Calculate the total age of the free-able blocks. We'll use it later to
2563*da0073e9SAndroid Build Coastguard Worker // derive the "avg age" threshold.
2564*da0073e9SAndroid Build Coastguard Worker size_t total_age = 0;
2565*da0073e9SAndroid Build Coastguard Worker int freeable_block_count = 0;
2566*da0073e9SAndroid Build Coastguard Worker for (auto& b : large_blocks.blocks) {
2567*da0073e9SAndroid Build Coastguard Worker if (!b->is_split()) {
2568*da0073e9SAndroid Build Coastguard Worker total_age += b->gc_count();
2569*da0073e9SAndroid Build Coastguard Worker ++freeable_block_count;
2570*da0073e9SAndroid Build Coastguard Worker }
2571*da0073e9SAndroid Build Coastguard Worker }
2572*da0073e9SAndroid Build Coastguard Worker // No free-able blocks?
2573*da0073e9SAndroid Build Coastguard Worker if (freeable_block_count == 0) {
2574*da0073e9SAndroid Build Coastguard Worker return;
2575*da0073e9SAndroid Build Coastguard Worker }
2576*da0073e9SAndroid Build Coastguard Worker
2577*da0073e9SAndroid Build Coastguard Worker // Repeat GC until we reach reclaim > target size.
2578*da0073e9SAndroid Build Coastguard Worker bool block_freed = true;
2579*da0073e9SAndroid Build Coastguard Worker while (gc_reclaimed < target_size && block_freed == true &&
2580*da0073e9SAndroid Build Coastguard Worker freeable_block_count > 0) {
2581*da0073e9SAndroid Build Coastguard Worker // Free blocks exceeding this age threshold first.
2582*da0073e9SAndroid Build Coastguard Worker double age_threshold =
2583*da0073e9SAndroid Build Coastguard Worker static_cast<double>(total_age) / freeable_block_count;
2584*da0073e9SAndroid Build Coastguard Worker // Stop iteration if we can no longer free a block.
2585*da0073e9SAndroid Build Coastguard Worker block_freed = false;
2586*da0073e9SAndroid Build Coastguard Worker
2587*da0073e9SAndroid Build Coastguard Worker // Free blocks of > avg age. Don't stop upon reaching the target_size;
2588*da0073e9SAndroid Build Coastguard Worker // we don't want this GC to be triggered frequently.
2589*da0073e9SAndroid Build Coastguard Worker auto it = large_blocks.blocks.begin();
2590*da0073e9SAndroid Build Coastguard Worker while (it != large_blocks.blocks.end()) {
2591*da0073e9SAndroid Build Coastguard Worker Block* block = *it;
2592*da0073e9SAndroid Build Coastguard Worker ++it;
2593*da0073e9SAndroid Build Coastguard Worker if (!block->is_split() && !block->expandable_segment_ &&
2594*da0073e9SAndroid Build Coastguard Worker static_cast<double>(block->gc_count()) >= age_threshold) {
2595*da0073e9SAndroid Build Coastguard Worker block_freed = true;
2596*da0073e9SAndroid Build Coastguard Worker gc_reclaimed += block->size;
2597*da0073e9SAndroid Build Coastguard Worker total_age -= block->gc_count(); // Decrement the age
2598*da0073e9SAndroid Build Coastguard Worker freeable_block_count--; // One less block that can be freed
2599*da0073e9SAndroid Build Coastguard Worker release_block(block, context);
2600*da0073e9SAndroid Build Coastguard Worker }
2601*da0073e9SAndroid Build Coastguard Worker }
2602*da0073e9SAndroid Build Coastguard Worker }
2603*da0073e9SAndroid Build Coastguard Worker }
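// [Editor's illustrative sketch; not part of the allocator.] Suppose three
// non-split large blocks have gc_count ages 1, 2 and 9. The first pass uses
// age_threshold = (1 + 2 + 9) / 3 = 4 and frees only the age-9 block; if
// gc_reclaimed is still below target_size, the next pass recomputes the
// threshold over the remaining blocks ((1 + 2) / 2 = 1.5) and frees the
// age-2 block, and so on until enough memory is reclaimed or no block can
// be freed.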
2604*da0073e9SAndroid Build Coastguard Worker
2605*da0073e9SAndroid Build Coastguard Worker // This function assumes that the global lock has been taken while calling
2606*da0073e9SAndroid Build Coastguard Worker // into this function. We make a synchronous cudaMalloc call here, which
2607*da0073e9SAndroid Build Coastguard Worker // can be expensive while holding the lock. Hence, we pass in the lock so it
2608*da0073e9SAndroid Build Coastguard Worker // can be temporarily released before the cudaMalloc call and re-acquired
2609*da0073e9SAndroid Build Coastguard Worker // afterwards, so that other threads don't get blocked.
2610*da0073e9SAndroid Build Coastguard Worker bool alloc_block(
2611*da0073e9SAndroid Build Coastguard Worker AllocParams& p,
2612*da0073e9SAndroid Build Coastguard Worker bool isRetry,
2613*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& ctx,
2614*da0073e9SAndroid Build Coastguard Worker std::unique_lock<std::recursive_mutex>& lock) {
2615*da0073e9SAndroid Build Coastguard Worker // Defensively checks for preexisting CUDA error state.
2616*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaGetLastError());
2617*da0073e9SAndroid Build Coastguard Worker
2618*da0073e9SAndroid Build Coastguard Worker size_t size = p.alloc_size;
2619*da0073e9SAndroid Build Coastguard Worker void* ptr = nullptr;
2620*da0073e9SAndroid Build Coastguard Worker
2621*da0073e9SAndroid Build Coastguard Worker if (isRetry) {
2622*da0073e9SAndroid Build Coastguard Worker stats.num_alloc_retries += 1;
2623*da0073e9SAndroid Build Coastguard Worker }
2624*da0073e9SAndroid Build Coastguard Worker #ifdef FBCODE_CAFFE2
2625*da0073e9SAndroid Build Coastguard Worker bool in_fbcode = true;
2626*da0073e9SAndroid Build Coastguard Worker #else
2627*da0073e9SAndroid Build Coastguard Worker bool in_fbcode = false;
2628*da0073e9SAndroid Build Coastguard Worker #endif
2629*da0073e9SAndroid Build Coastguard Worker
2630*da0073e9SAndroid Build Coastguard Worker if (set_fraction &&
2631*da0073e9SAndroid Build Coastguard Worker total_allocated_memory + size > allowed_memory_maximum) {
2632*da0073e9SAndroid Build Coastguard Worker p.err = cudaErrorMemoryAllocation;
2633*da0073e9SAndroid Build Coastguard Worker return false;
2634*da0073e9SAndroid Build Coastguard Worker // Temporarily disable checkpointing & cudagraphs internally
2635*da0073e9SAndroid Build Coastguard Worker } else if (
2636*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig::expandable_segments() &&
2637*da0073e9SAndroid Build Coastguard Worker !(in_fbcode && p.pool->owner_PrivatePool)) {
2638*da0073e9SAndroid Build Coastguard Worker p.block = try_allocate_expandable_block(
2639*da0073e9SAndroid Build Coastguard Worker p.device(), p.stream(), p.pool, p.size(), ctx);
2640*da0073e9SAndroid Build Coastguard Worker if (p.block) {
2641*da0073e9SAndroid Build Coastguard Worker p.err = cudaSuccess;
2642*da0073e9SAndroid Build Coastguard Worker if (p.pool->owner_PrivatePool) {
2643*da0073e9SAndroid Build Coastguard Worker // The block is for a CUDA graph's PrivatePool.
2644*da0073e9SAndroid Build Coastguard Worker p.pool->owner_PrivatePool->cudaMalloc_count++;
2645*da0073e9SAndroid Build Coastguard Worker }
2646*da0073e9SAndroid Build Coastguard Worker } else {
2647*da0073e9SAndroid Build Coastguard Worker p.err = cudaErrorMemoryAllocation;
2648*da0073e9SAndroid Build Coastguard Worker }
2649*da0073e9SAndroid Build Coastguard Worker return bool(p.block);
2650*da0073e9SAndroid Build Coastguard Worker } else {
2651*da0073e9SAndroid Build Coastguard Worker if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
2652*da0073e9SAndroid Build Coastguard Worker // At scope exit, acquire the lock again. This provides safety against
2653*da0073e9SAndroid Build Coastguard Worker // any potential exceptions in the cudaMallocMaybeCapturing function.
2654*da0073e9SAndroid Build Coastguard Worker auto sg = c10::make_scope_exit([&]() { lock.lock(); });
2655*da0073e9SAndroid Build Coastguard Worker lock.unlock();
2656*da0073e9SAndroid Build Coastguard Worker }
2657*da0073e9SAndroid Build Coastguard Worker auto active_pool = MemPoolContext::getActiveMemPool();
2658*da0073e9SAndroid Build Coastguard Worker if (active_pool && active_pool->allocator() &&
2659*da0073e9SAndroid Build Coastguard Worker p.pool->owner_PrivatePool) {
2660*da0073e9SAndroid Build Coastguard Worker ptr = active_pool->allocator()->raw_alloc(size);
2661*da0073e9SAndroid Build Coastguard Worker p.err = ptr ? cudaSuccess : cudaErrorMemoryAllocation;
2662*da0073e9SAndroid Build Coastguard Worker } else {
2663*da0073e9SAndroid Build Coastguard Worker p.err = cudaMallocMaybeCapturing(&ptr, size);
2664*da0073e9SAndroid Build Coastguard Worker }
2665*da0073e9SAndroid Build Coastguard Worker if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
2666*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
2667*da0073e9SAndroid Build Coastguard Worker lock.owns_lock(), "Failed to acquire lock after cudaMalloc");
2668*da0073e9SAndroid Build Coastguard Worker }
2669*da0073e9SAndroid Build Coastguard Worker
2670*da0073e9SAndroid Build Coastguard Worker if (p.err != cudaSuccess) {
2671*da0073e9SAndroid Build Coastguard Worker if (p.err == cudaErrorMemoryAllocation) {
2672*da0073e9SAndroid Build Coastguard Worker // If this is the first attempt (!isRetry), we can forgive and clear
2673*da0073e9SAndroid Build Coastguard Worker // CUDA's internal error state.
2674*da0073e9SAndroid Build Coastguard Worker //
2675*da0073e9SAndroid Build Coastguard Worker // If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH
2676*da0073e9SAndroid Build Coastguard Worker // will take over to throw a helpful exception. The user can choose
2677*da0073e9SAndroid Build Coastguard Worker // to catch the exception, free some stuff in their script, and
2678*da0073e9SAndroid Build Coastguard Worker // attempt the allocation again. In this case, we can also forgive and
2679*da0073e9SAndroid Build Coastguard Worker // clear CUDA's internal error state.
2680*da0073e9SAndroid Build Coastguard Worker (void)cudaGetLastError();
2681*da0073e9SAndroid Build Coastguard Worker } else {
2682*da0073e9SAndroid Build Coastguard Worker // If the error's unrelated to memory allocation, we should throw
2683*da0073e9SAndroid Build Coastguard Worker // immediately.
2684*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(p.err);
2685*da0073e9SAndroid Build Coastguard Worker }
2686*da0073e9SAndroid Build Coastguard Worker return false;
2687*da0073e9SAndroid Build Coastguard Worker }
2688*da0073e9SAndroid Build Coastguard Worker }
2689*da0073e9SAndroid Build Coastguard Worker
2690*da0073e9SAndroid Build Coastguard Worker if (p.pool->owner_PrivatePool) {
2691*da0073e9SAndroid Build Coastguard Worker // The block is for a CUDA graph's PrivatePool.
2692*da0073e9SAndroid Build Coastguard Worker p.pool->owner_PrivatePool->cudaMalloc_count++;
2693*da0073e9SAndroid Build Coastguard Worker }
2694*da0073e9SAndroid Build Coastguard Worker
2695*da0073e9SAndroid Build Coastguard Worker total_allocated_memory += size;
2696*da0073e9SAndroid Build Coastguard Worker p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
2697*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
2698*da0073e9SAndroid Build Coastguard Worker stats.segment[stat_type].increase(1);
2699*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[stat_type].increase(size);
2700*da0073e9SAndroid Build Coastguard Worker });
2701*da0073e9SAndroid Build Coastguard Worker if (size >= CUDAAllocatorConfig::max_split_size())
2702*da0073e9SAndroid Build Coastguard Worker stats.oversize_segments.increase(1);
2703*da0073e9SAndroid Build Coastguard Worker auto reserved_bytes_gauge =
2704*da0073e9SAndroid Build Coastguard Worker STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
2705*da0073e9SAndroid Build Coastguard Worker reserved_bytes_gauge.record(
2706*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
2707*da0073e9SAndroid Build Coastguard Worker .current);
2708*da0073e9SAndroid Build Coastguard Worker
2709*da0073e9SAndroid Build Coastguard Worker // p.block came from new, not cudaMalloc. It should not be nullptr here.
2710*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
2711*da0073e9SAndroid Build Coastguard Worker stats.num_device_alloc++;
2712*da0073e9SAndroid Build Coastguard Worker record_trace(
2713*da0073e9SAndroid Build Coastguard Worker TraceEntry::SEGMENT_ALLOC,
2714*da0073e9SAndroid Build Coastguard Worker int64_t(p.block->ptr),
2715*da0073e9SAndroid Build Coastguard Worker p.block->size,
2716*da0073e9SAndroid Build Coastguard Worker p.stream(),
2717*da0073e9SAndroid Build Coastguard Worker p.device(),
2718*da0073e9SAndroid Build Coastguard Worker ctx);
2719*da0073e9SAndroid Build Coastguard Worker p.block->context_when_segment_allocated = ctx;
2720*da0073e9SAndroid Build Coastguard Worker return true;
2721*da0073e9SAndroid Build Coastguard Worker }
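// [Editor's illustrative sketch; not part of the allocator.] The
// release_lock_on_cudamalloc path above follows the usual "unlock around an
// expensive call" pattern, with the scope guard re-locking even if the call
// throws (expensive_call is a hypothetical stand-in for the synchronous
// cudaMalloc above):
//
//   std::unique_lock<std::recursive_mutex> lock(mutex);
//   {
//     auto relock = c10::make_scope_exit([&] { lock.lock(); });
//     lock.unlock();
//     expensive_call();
//   }  // relock runs here (or during stack unwinding)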
2722*da0073e9SAndroid Build Coastguard Worker
2723*da0073e9SAndroid Build Coastguard Worker /** Free one or more oversize blocks to the system allocator, but only
2724*da0073e9SAndroid Build Coastguard Worker  * enough to satisfy the target size. **/
2726*da0073e9SAndroid Build Coastguard Worker bool release_available_cached_blocks(
2727*da0073e9SAndroid Build Coastguard Worker const AllocParams& p,
2728*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& context) {
2729*da0073e9SAndroid Build Coastguard Worker if (CUDAAllocatorConfig::max_split_size() ==
2730*da0073e9SAndroid Build Coastguard Worker std::numeric_limits<size_t>::max())
2731*da0073e9SAndroid Build Coastguard Worker return false;
2732*da0073e9SAndroid Build Coastguard Worker BlockPool& pool = *p.pool;
2733*da0073e9SAndroid Build Coastguard Worker
2734*da0073e9SAndroid Build Coastguard Worker // because of std::unique_ptr, block cannot be trivially copied
2735*da0073e9SAndroid Build Coastguard Worker // Use constructor for search key.
2736*da0073e9SAndroid Build Coastguard Worker Block key(p.search_key.device, p.search_key.stream, p.search_key.size);
2737*da0073e9SAndroid Build Coastguard Worker key.size = (key.size < CUDAAllocatorConfig::max_split_size())
2738*da0073e9SAndroid Build Coastguard Worker ? CUDAAllocatorConfig::max_split_size()
2739*da0073e9SAndroid Build Coastguard Worker : key.size;
2740*da0073e9SAndroid Build Coastguard Worker auto it = pool.blocks.lower_bound(&key);
2741*da0073e9SAndroid Build Coastguard Worker if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
2742*da0073e9SAndroid Build Coastguard Worker (*it)->expandable_segment_) {
2743*da0073e9SAndroid Build Coastguard Worker // No single block is large enough; free multiple oversize blocks,
2744*da0073e9SAndroid Build Coastguard Worker // starting with the largest
2745*da0073e9SAndroid Build Coastguard Worker if (it == pool.blocks.begin())
2746*da0073e9SAndroid Build Coastguard Worker return false;
2747*da0073e9SAndroid Build Coastguard Worker size_t totalReleased = 0;
2748*da0073e9SAndroid Build Coastguard Worker --it; // Back up one item. Now on the largest block for the correct
2749*da0073e9SAndroid Build Coastguard Worker // stream
2750*da0073e9SAndroid Build Coastguard Worker while ((totalReleased < key.size) &&
2751*da0073e9SAndroid Build Coastguard Worker ((*it)->size >= CUDAAllocatorConfig::max_split_size()) &&
2752*da0073e9SAndroid Build Coastguard Worker ((*it)->stream == p.stream())) {
2753*da0073e9SAndroid Build Coastguard Worker auto cur = it;
2754*da0073e9SAndroid Build Coastguard Worker bool is_first = cur == pool.blocks.begin();
2755*da0073e9SAndroid Build Coastguard Worker if (!is_first) {
2756*da0073e9SAndroid Build Coastguard Worker --it;
2757*da0073e9SAndroid Build Coastguard Worker }
2758*da0073e9SAndroid Build Coastguard Worker if (!(*cur)->expandable_segment_) {
2759*da0073e9SAndroid Build Coastguard Worker release_block(*cur, context);
2760*da0073e9SAndroid Build Coastguard Worker totalReleased += (*cur)->size;
2761*da0073e9SAndroid Build Coastguard Worker }
2762*da0073e9SAndroid Build Coastguard Worker if (is_first) {
2763*da0073e9SAndroid Build Coastguard Worker break;
2764*da0073e9SAndroid Build Coastguard Worker }
2765*da0073e9SAndroid Build Coastguard Worker }
2766*da0073e9SAndroid Build Coastguard Worker if (totalReleased < key.size)
2767*da0073e9SAndroid Build Coastguard Worker return false;
2768*da0073e9SAndroid Build Coastguard Worker } else {
2769*da0073e9SAndroid Build Coastguard Worker release_block(*it, context);
2770*da0073e9SAndroid Build Coastguard Worker }
2771*da0073e9SAndroid Build Coastguard Worker return true;
2772*da0073e9SAndroid Build Coastguard Worker }
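// [Editor's illustrative sketch; not part of the allocator.] If no single
// cached block satisfies the request, the loop above walks backwards from the
// largest same-stream block. For example, with max_split_size configured
// below the request size, a 512 MiB request might release cached oversize
// blocks of 300 MiB and 256 MiB and stop once totalReleased >= 512 MiB; if
// even the sum of releasable oversize blocks falls short, it reports failure.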
2773*da0073e9SAndroid Build Coastguard Worker
2774*da0073e9SAndroid Build Coastguard Worker bool release_cached_blocks(const std::shared_ptr<GatheredContext>& context) {
2775*da0073e9SAndroid Build Coastguard Worker // First ensure that all blocks that can't currently be allocated due to
2776*da0073e9SAndroid Build Coastguard Worker // outstanding events are returned to the pool.
2777*da0073e9SAndroid Build Coastguard Worker synchronize_and_free_events(context);
2778*da0073e9SAndroid Build Coastguard Worker
2779*da0073e9SAndroid Build Coastguard Worker // Free all non-split cached blocks to system allocator
2780*da0073e9SAndroid Build Coastguard Worker release_blocks(large_blocks, context);
2781*da0073e9SAndroid Build Coastguard Worker release_blocks(small_blocks, context);
2782*da0073e9SAndroid Build Coastguard Worker
2783*da0073e9SAndroid Build Coastguard Worker for (auto it = graph_pools_freeable.begin();
2784*da0073e9SAndroid Build Coastguard Worker it != graph_pools_freeable.end();) {
2785*da0073e9SAndroid Build Coastguard Worker // See notifyCaptureDestroy for the strategy here.
2786*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
2787*da0073e9SAndroid Build Coastguard Worker release_blocks(it->second->small_blocks, context);
2788*da0073e9SAndroid Build Coastguard Worker release_blocks(it->second->large_blocks, context);
2789*da0073e9SAndroid Build Coastguard Worker if (it->second->cudaMalloc_count == 0) {
2790*da0073e9SAndroid Build Coastguard Worker auto erase_count = graph_pools.erase(it->first);
2791*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(erase_count == 1);
2792*da0073e9SAndroid Build Coastguard Worker it = graph_pools_freeable.erase(it);
2793*da0073e9SAndroid Build Coastguard Worker } else {
2794*da0073e9SAndroid Build Coastguard Worker ++it;
2795*da0073e9SAndroid Build Coastguard Worker }
2796*da0073e9SAndroid Build Coastguard Worker }
2797*da0073e9SAndroid Build Coastguard Worker
2798*da0073e9SAndroid Build Coastguard Worker return true;
2799*da0073e9SAndroid Build Coastguard Worker }
2800*da0073e9SAndroid Build Coastguard Worker
2801*da0073e9SAndroid Build Coastguard Worker void release_expandable_segment(Block* block) {
2802*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
2803*da0073e9SAndroid Build Coastguard Worker block->size == block->expandable_segment_->size(),
2804*da0073e9SAndroid Build Coastguard Worker "block disagrees with segment");
2805*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(!block->mapped);
2806*da0073e9SAndroid Build Coastguard Worker auto it = std::find(
2807*da0073e9SAndroid Build Coastguard Worker expandable_segments_.begin(),
2808*da0073e9SAndroid Build Coastguard Worker expandable_segments_.end(),
2809*da0073e9SAndroid Build Coastguard Worker block->expandable_segment_);
2810*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(it != expandable_segments_.end());
2811*da0073e9SAndroid Build Coastguard Worker expandable_segments_.erase(it);
2812*da0073e9SAndroid Build Coastguard Worker block->pool->unmapped.erase(block);
2813*da0073e9SAndroid Build Coastguard Worker delete block->expandable_segment_;
2814*da0073e9SAndroid Build Coastguard Worker delete block;
2815*da0073e9SAndroid Build Coastguard Worker }
2816*da0073e9SAndroid Build Coastguard Worker
2817*da0073e9SAndroid Build Coastguard Worker void release_block(
2818*da0073e9SAndroid Build Coastguard Worker Block* block,
2819*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& context) {
2820*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(!block->expandable_segment_);
2821*da0073e9SAndroid Build Coastguard Worker stats.num_device_free++;
2822*da0073e9SAndroid Build Coastguard Worker record_trace(
2823*da0073e9SAndroid Build Coastguard Worker TraceEntry::SEGMENT_FREE,
2824*da0073e9SAndroid Build Coastguard Worker int64_t(block->ptr),
2825*da0073e9SAndroid Build Coastguard Worker block->size,
2826*da0073e9SAndroid Build Coastguard Worker block->stream,
2827*da0073e9SAndroid Build Coastguard Worker block->device,
2828*da0073e9SAndroid Build Coastguard Worker context ? context : block->context_when_segment_allocated);
2829*da0073e9SAndroid Build Coastguard Worker
2830*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaFree((void*)block->ptr));
2831*da0073e9SAndroid Build Coastguard Worker total_allocated_memory -= block->size;
2832*da0073e9SAndroid Build Coastguard Worker
2833*da0073e9SAndroid Build Coastguard Worker auto* pool = block->pool;
2834*da0073e9SAndroid Build Coastguard Worker if (pool->owner_PrivatePool) {
2835*da0073e9SAndroid Build Coastguard Worker // The cudaFreed block belonged to a CUDA graph's PrivatePool.
2836*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(pool->owner_PrivatePool->cudaMalloc_count > 0);
2837*da0073e9SAndroid Build Coastguard Worker pool->owner_PrivatePool->cudaMalloc_count--;
2838*da0073e9SAndroid Build Coastguard Worker }
2839*da0073e9SAndroid Build Coastguard Worker
2840*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = get_stat_types_for_pool(*pool);
2841*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
2842*da0073e9SAndroid Build Coastguard Worker stats.segment[stat_type].decrease(1);
2843*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[stat_type].decrease(block->size);
2844*da0073e9SAndroid Build Coastguard Worker });
2845*da0073e9SAndroid Build Coastguard Worker auto reserved_bytes_gauge =
2846*da0073e9SAndroid Build Coastguard Worker STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
2847*da0073e9SAndroid Build Coastguard Worker reserved_bytes_gauge.record(
2848*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
2849*da0073e9SAndroid Build Coastguard Worker .current);
2850*da0073e9SAndroid Build Coastguard Worker
2851*da0073e9SAndroid Build Coastguard Worker if (block->size >= CUDAAllocatorConfig::max_split_size())
2852*da0073e9SAndroid Build Coastguard Worker stats.oversize_segments.decrease(1);
2853*da0073e9SAndroid Build Coastguard Worker pool->blocks.erase(block);
2854*da0073e9SAndroid Build Coastguard Worker delete block;
2855*da0073e9SAndroid Build Coastguard Worker }
2856*da0073e9SAndroid Build Coastguard Worker
2857*da0073e9SAndroid Build Coastguard Worker void unmap_block(
2858*da0073e9SAndroid Build Coastguard Worker Block* block,
2859*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& context) {
2860*da0073e9SAndroid Build Coastguard Worker auto unmapped = block->expandable_segment_->unmap(
2861*da0073e9SAndroid Build Coastguard Worker SegmentRange{block->ptr, block->size});
2862*da0073e9SAndroid Build Coastguard Worker if (unmapped.size == 0) {
2863*da0073e9SAndroid Build Coastguard Worker return;
2864*da0073e9SAndroid Build Coastguard Worker }
2865*da0073e9SAndroid Build Coastguard Worker block->pool->blocks.erase(block);
2866*da0073e9SAndroid Build Coastguard Worker
2867*da0073e9SAndroid Build Coastguard Worker ptrdiff_t before_size =
2868*da0073e9SAndroid Build Coastguard Worker static_cast<char*>(unmapped.ptr) - static_cast<char*>(block->ptr);
2869*da0073e9SAndroid Build Coastguard Worker if (before_size > 0) {
2870*da0073e9SAndroid Build Coastguard Worker // prev? -> before_free -> block
2871*da0073e9SAndroid Build Coastguard Worker Block* before_free = new Block(
2872*da0073e9SAndroid Build Coastguard Worker block->device, block->stream, before_size, block->pool, block->ptr);
2873*da0073e9SAndroid Build Coastguard Worker before_free->expandable_segment_ = block->expandable_segment_;
2874*da0073e9SAndroid Build Coastguard Worker before_free->splice(block->prev, block);
2875*da0073e9SAndroid Build Coastguard Worker block->pool->insert_into_blocks(before_free);
2876*da0073e9SAndroid Build Coastguard Worker }
2877*da0073e9SAndroid Build Coastguard Worker
2878*da0073e9SAndroid Build Coastguard Worker auto after_size = block->size - (before_size + unmapped.size);
2879*da0073e9SAndroid Build Coastguard Worker if (after_size > 0) {
2880*da0073e9SAndroid Build Coastguard Worker // block -> after_free -> next?
2881*da0073e9SAndroid Build Coastguard Worker Block* after_free = new Block(
2882*da0073e9SAndroid Build Coastguard Worker block->device,
2883*da0073e9SAndroid Build Coastguard Worker block->stream,
2884*da0073e9SAndroid Build Coastguard Worker after_size,
2885*da0073e9SAndroid Build Coastguard Worker block->pool,
2886*da0073e9SAndroid Build Coastguard Worker static_cast<char*>(unmapped.ptr) + unmapped.size);
2887*da0073e9SAndroid Build Coastguard Worker after_free->expandable_segment_ = block->expandable_segment_;
2888*da0073e9SAndroid Build Coastguard Worker after_free->splice(block, block->next);
2889*da0073e9SAndroid Build Coastguard Worker block->pool->insert_into_blocks(after_free);
2890*da0073e9SAndroid Build Coastguard Worker }
2891*da0073e9SAndroid Build Coastguard Worker
2892*da0073e9SAndroid Build Coastguard Worker block->ptr = unmapped.ptr;
2893*da0073e9SAndroid Build Coastguard Worker block->size = unmapped.size;
2894*da0073e9SAndroid Build Coastguard Worker block->mapped = false;
2895*da0073e9SAndroid Build Coastguard Worker
2896*da0073e9SAndroid Build Coastguard Worker try_merge_blocks(block, block->prev, *block->pool);
2897*da0073e9SAndroid Build Coastguard Worker try_merge_blocks(block, block->next, *block->pool);
2898*da0073e9SAndroid Build Coastguard Worker block->pool->unmapped.insert(block);
2899*da0073e9SAndroid Build Coastguard Worker
2900*da0073e9SAndroid Build Coastguard Worker // update statistics
2901*da0073e9SAndroid Build Coastguard Worker total_allocated_memory -= unmapped.size;
2902*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = get_stat_types_for_pool(*block->pool);
2903*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
2904*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[stat_type].decrease(unmapped.size);
2905*da0073e9SAndroid Build Coastguard Worker });
2906*da0073e9SAndroid Build Coastguard Worker auto reserved_bytes_gauge =
2907*da0073e9SAndroid Build Coastguard Worker STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
2908*da0073e9SAndroid Build Coastguard Worker reserved_bytes_gauge.record(
2909*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
2910*da0073e9SAndroid Build Coastguard Worker .current);
2911*da0073e9SAndroid Build Coastguard Worker
2912*da0073e9SAndroid Build Coastguard Worker if (block->pool->owner_PrivatePool) {
2913*da0073e9SAndroid Build Coastguard Worker // The cudaFreed block belonged to a CUDA graph's PrivatePool.
2914*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
2915*da0073e9SAndroid Build Coastguard Worker block->pool->owner_PrivatePool->cudaMalloc_count > 0);
2916*da0073e9SAndroid Build Coastguard Worker block->pool->owner_PrivatePool->cudaMalloc_count--;
2917*da0073e9SAndroid Build Coastguard Worker }
2918*da0073e9SAndroid Build Coastguard Worker
2919*da0073e9SAndroid Build Coastguard Worker stats.num_device_free++;
2920*da0073e9SAndroid Build Coastguard Worker record_trace(
2921*da0073e9SAndroid Build Coastguard Worker TraceEntry::SEGMENT_UNMAP,
2922*da0073e9SAndroid Build Coastguard Worker int64_t(unmapped.ptr),
2923*da0073e9SAndroid Build Coastguard Worker unmapped.size,
2924*da0073e9SAndroid Build Coastguard Worker block->stream,
2925*da0073e9SAndroid Build Coastguard Worker block->device,
2926*da0073e9SAndroid Build Coastguard Worker context ? context : block->context_when_segment_allocated);
2927*da0073e9SAndroid Build Coastguard Worker }
2928*da0073e9SAndroid Build Coastguard Worker void release_blocks(
2929*da0073e9SAndroid Build Coastguard Worker BlockPool& pool,
2930*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& context) {
2931*da0073e9SAndroid Build Coastguard Worker std::vector<Block*> to_unmap;
2932*da0073e9SAndroid Build Coastguard Worker // Frees all non-split blocks
2933*da0073e9SAndroid Build Coastguard Worker auto it = pool.blocks.begin();
2934*da0073e9SAndroid Build Coastguard Worker while (it != pool.blocks.end()) {
2935*da0073e9SAndroid Build Coastguard Worker Block* block = *it;
2936*da0073e9SAndroid Build Coastguard Worker ++it;
2937*da0073e9SAndroid Build Coastguard Worker if (block->expandable_segment_) {
2938*da0073e9SAndroid Build Coastguard Worker // unmapping will mutate the free pool
2939*da0073e9SAndroid Build Coastguard Worker // so just gather what needs to be freed
2940*da0073e9SAndroid Build Coastguard Worker // to avoid invalidating the iterator
2941*da0073e9SAndroid Build Coastguard Worker to_unmap.push_back(block);
2942*da0073e9SAndroid Build Coastguard Worker } else if (!block->prev && !block->next) {
2943*da0073e9SAndroid Build Coastguard Worker release_block(block, context);
2944*da0073e9SAndroid Build Coastguard Worker }
2945*da0073e9SAndroid Build Coastguard Worker }
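    // Unmapping merges the freed range with adjacent unmapped neighbors; once
    // a segment has collapsed into a single unmapped block (no prev/next), the
    // expandable segment itself can be released.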
2946*da0073e9SAndroid Build Coastguard Worker for (Block* block : to_unmap) {
2947*da0073e9SAndroid Build Coastguard Worker unmap_block(block, context);
2948*da0073e9SAndroid Build Coastguard Worker if (!block->prev && !block->next) {
2949*da0073e9SAndroid Build Coastguard Worker release_expandable_segment(block);
2950*da0073e9SAndroid Build Coastguard Worker }
2951*da0073e9SAndroid Build Coastguard Worker }
2952*da0073e9SAndroid Build Coastguard Worker }
2953*da0073e9SAndroid Build Coastguard Worker
2954*da0073e9SAndroid Build Coastguard Worker EventPool::Event create_event_internal(c10::DeviceIndex idx) {
2955*da0073e9SAndroid Build Coastguard Worker // Leak the event pool to avoid shutdown issues.
2956*da0073e9SAndroid Build Coastguard Worker static auto* event_pool = new EventPool();
2957*da0073e9SAndroid Build Coastguard Worker return event_pool->get(idx);
2958*da0073e9SAndroid Build Coastguard Worker }
2959*da0073e9SAndroid Build Coastguard Worker
2960*da0073e9SAndroid Build Coastguard Worker void synchronize_and_free_events(
2961*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& context) {
2962*da0073e9SAndroid Build Coastguard Worker // Synchronize on outstanding events and then free associated blocks.
2963*da0073e9SAndroid Build Coastguard Worker stats.num_sync_all_streams++;
2964*da0073e9SAndroid Build Coastguard Worker
2965*da0073e9SAndroid Build Coastguard Worker // This function syncs, so capture should not be underway. Might as well
2966*da0073e9SAndroid Build Coastguard Worker // make sure capture-deferred end of life events get processed too.
2967*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(captures_underway.empty());
2968*da0073e9SAndroid Build Coastguard Worker insert_events_deferred_until_no_capture(context);
2969*da0073e9SAndroid Build Coastguard Worker
2970*da0073e9SAndroid Build Coastguard Worker for (auto& st : cuda_events) {
2971*da0073e9SAndroid Build Coastguard Worker for (auto& e : st.second) {
2972*da0073e9SAndroid Build Coastguard Worker EventPool::Event event = std::move(e.first);
2973*da0073e9SAndroid Build Coastguard Worker Block* block = e.second;
2974*da0073e9SAndroid Build Coastguard Worker
2975*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaEventSynchronize(*event));
2976*da0073e9SAndroid Build Coastguard Worker
2977*da0073e9SAndroid Build Coastguard Worker block->event_count--;
2978*da0073e9SAndroid Build Coastguard Worker if (block->event_count == 0) {
2979*da0073e9SAndroid Build Coastguard Worker free_block(block, context);
2980*da0073e9SAndroid Build Coastguard Worker }
2981*da0073e9SAndroid Build Coastguard Worker }
2982*da0073e9SAndroid Build Coastguard Worker }
2983*da0073e9SAndroid Build Coastguard Worker
2984*da0073e9SAndroid Build Coastguard Worker cuda_events.clear();
2985*da0073e9SAndroid Build Coastguard Worker }
2986*da0073e9SAndroid Build Coastguard Worker
2987*da0073e9SAndroid Build Coastguard Worker void remove_cudagraph_stream_uses(Block* block) {
2988*da0073e9SAndroid Build Coastguard Worker // remove stream uses added during cudagraph capture
2989*da0073e9SAndroid Build Coastguard Worker // (i.e., block->stream_uses - block->cudagraph_stream_uses)
2990*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(
2991*da0073e9SAndroid Build Coastguard Worker block_to_cudagraph_stream_uses.find(block) !=
2992*da0073e9SAndroid Build Coastguard Worker block_to_cudagraph_stream_uses.end())) {
2993*da0073e9SAndroid Build Coastguard Worker stream_set streams(std::move(block->stream_uses));
2994*da0073e9SAndroid Build Coastguard Worker AT_ASSERT(block->stream_uses.empty());
2995*da0073e9SAndroid Build Coastguard Worker for (auto& stream : streams) {
2996*da0073e9SAndroid Build Coastguard Worker if (block_to_cudagraph_stream_uses[block].find(stream) ==
2997*da0073e9SAndroid Build Coastguard Worker block_to_cudagraph_stream_uses[block].end()) {
2998*da0073e9SAndroid Build Coastguard Worker block->stream_uses.insert(stream);
2999*da0073e9SAndroid Build Coastguard Worker }
3000*da0073e9SAndroid Build Coastguard Worker }
3001*da0073e9SAndroid Build Coastguard Worker block_to_cudagraph_stream_uses.erase(block);
3002*da0073e9SAndroid Build Coastguard Worker }
3003*da0073e9SAndroid Build Coastguard Worker }
3004*da0073e9SAndroid Build Coastguard Worker
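  // insert_events() records one CUDA event per stream the block was used on;
  // the block is only handed back to the pool (via process_events/free_block)
  // after every such stream has passed the point at which the block was freed.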
3005*da0073e9SAndroid Build Coastguard Worker void insert_events(Block* block) {
3006*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex prev_device = 0;
3007*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::GetDevice(&prev_device));
3008*da0073e9SAndroid Build Coastguard Worker
3009*da0073e9SAndroid Build Coastguard Worker stream_set streams(std::move(block->stream_uses));
3010*da0073e9SAndroid Build Coastguard Worker AT_ASSERT(block->stream_uses.empty());
3011*da0073e9SAndroid Build Coastguard Worker for (auto& stream : streams) {
3012*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::SetDevice(stream.device_index()));
3013*da0073e9SAndroid Build Coastguard Worker
3014*da0073e9SAndroid Build Coastguard Worker EventPool::Event event = create_event_internal(stream.device_index());
3015*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaEventRecord(*event, stream.stream()));
3016*da0073e9SAndroid Build Coastguard Worker
3017*da0073e9SAndroid Build Coastguard Worker block->event_count++;
3018*da0073e9SAndroid Build Coastguard Worker cuda_events[stream].emplace_back(std::move(event), block);
3019*da0073e9SAndroid Build Coastguard Worker }
3020*da0073e9SAndroid Build Coastguard Worker
3021*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::MaybeSetDevice(prev_device));
3022*da0073e9SAndroid Build Coastguard Worker }
3023*da0073e9SAndroid Build Coastguard Worker
3024*da0073e9SAndroid Build Coastguard Worker void insert_events_deferred_until_no_capture(
3025*da0073e9SAndroid Build Coastguard Worker const std::shared_ptr<GatheredContext>& context) {
3026*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(!needs_events_deferred_until_no_capture.empty())) {
3027*da0073e9SAndroid Build Coastguard Worker for (auto* block : needs_events_deferred_until_no_capture) {
3028*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
3029*da0073e9SAndroid Build Coastguard Worker // only streams recorded before cudagraph capture will be used to insert
3030*da0073e9SAndroid Build Coastguard Worker // events, since we know all work on streams recorded during capture must
3031*da0073e9SAndroid Build Coastguard Worker // have completed (refer to Section 3.2.8.7.3.1, Cross-stream Dependencies
3032*da0073e9SAndroid Build Coastguard Worker // and Events, in the CUDA Programming Guide).
3033*da0073e9SAndroid Build Coastguard Worker remove_cudagraph_stream_uses(block);
3034*da0073e9SAndroid Build Coastguard Worker insert_events(block);
3035*da0073e9SAndroid Build Coastguard Worker if (block->event_count == 0) {
3036*da0073e9SAndroid Build Coastguard Worker free_block(block, context);
3037*da0073e9SAndroid Build Coastguard Worker }
3038*da0073e9SAndroid Build Coastguard Worker }
3039*da0073e9SAndroid Build Coastguard Worker needs_events_deferred_until_no_capture.clear();
3040*da0073e9SAndroid Build Coastguard Worker }
3041*da0073e9SAndroid Build Coastguard Worker }
3042*da0073e9SAndroid Build Coastguard Worker
3043*da0073e9SAndroid Build Coastguard Worker void process_events(const std::shared_ptr<GatheredContext>& context) {
3044*da0073e9SAndroid Build Coastguard Worker insert_events_deferred_until_no_capture(context);
3045*da0073e9SAndroid Build Coastguard Worker
3046*da0073e9SAndroid Build Coastguard Worker // Process outstanding cudaEvents. Events that are completed are
3047*da0073e9SAndroid Build Coastguard Worker // removed from the queue, and the 'event_count' for the
3048*da0073e9SAndroid Build Coastguard Worker // corresponding allocation is decremented. We maintain a separate
3049*da0073e9SAndroid Build Coastguard Worker // list of events per stream to avoid head-of-line delays if one
3050*da0073e9SAndroid Build Coastguard Worker // or more streams has long-running operations.
3051*da0073e9SAndroid Build Coastguard Worker
3052*da0073e9SAndroid Build Coastguard Worker // Iterate over different streams.
3053*da0073e9SAndroid Build Coastguard Worker for (auto it = cuda_events.begin(); it != cuda_events.end();) {
3054*da0073e9SAndroid Build Coastguard Worker // Iterate over this stream's (event, block) pairs.
3055*da0073e9SAndroid Build Coastguard Worker while (!it->second.empty()) {
3056*da0073e9SAndroid Build Coastguard Worker auto& e = it->second.front();
3057*da0073e9SAndroid Build Coastguard Worker EventPool::Event event = std::move(e.first);
3058*da0073e9SAndroid Build Coastguard Worker Block* block = e.second;
3059*da0073e9SAndroid Build Coastguard Worker
3060*da0073e9SAndroid Build Coastguard Worker cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(*event));
3061*da0073e9SAndroid Build Coastguard Worker if (err == cudaErrorNotReady) {
3062*da0073e9SAndroid Build Coastguard Worker // ignore and clear the error if not ready
3063*da0073e9SAndroid Build Coastguard Worker (void)cudaGetLastError();
3064*da0073e9SAndroid Build Coastguard Worker // Return the ownership of the Event (unique ptr)
3065*da0073e9SAndroid Build Coastguard Worker e.first = std::move(event);
3066*da0073e9SAndroid Build Coastguard Worker break;
3067*da0073e9SAndroid Build Coastguard Worker } else if (err != cudaSuccess) {
3068*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(err);
3069*da0073e9SAndroid Build Coastguard Worker }
3070*da0073e9SAndroid Build Coastguard Worker
3071*da0073e9SAndroid Build Coastguard Worker block->event_count--;
3072*da0073e9SAndroid Build Coastguard Worker if (block->event_count == 0) {
3073*da0073e9SAndroid Build Coastguard Worker free_block(block, context);
3074*da0073e9SAndroid Build Coastguard Worker }
3075*da0073e9SAndroid Build Coastguard Worker it->second.pop_front();
3076*da0073e9SAndroid Build Coastguard Worker }
3077*da0073e9SAndroid Build Coastguard Worker
3078*da0073e9SAndroid Build Coastguard Worker if (it->second.empty()) {
3079*da0073e9SAndroid Build Coastguard Worker it = cuda_events.erase(it);
3080*da0073e9SAndroid Build Coastguard Worker } else {
3081*da0073e9SAndroid Build Coastguard Worker it++;
3082*da0073e9SAndroid Build Coastguard Worker }
3083*da0073e9SAndroid Build Coastguard Worker }
3084*da0073e9SAndroid Build Coastguard Worker }
3085*da0073e9SAndroid Build Coastguard Worker
3086*da0073e9SAndroid Build Coastguard Worker // Iterates over sizes of all memory blocks for given device in given pool
3087*da0073e9SAndroid Build Coastguard Worker void cache_info_aux(const BlockPool& pool, size_t* largest) {
3088*da0073e9SAndroid Build Coastguard Worker for (const auto& block : pool.blocks) {
3089*da0073e9SAndroid Build Coastguard Worker const auto blocksize = block->size;
3090*da0073e9SAndroid Build Coastguard Worker if (blocksize > *largest) {
3091*da0073e9SAndroid Build Coastguard Worker *largest = blocksize;
3092*da0073e9SAndroid Build Coastguard Worker }
3093*da0073e9SAndroid Build Coastguard Worker }
3094*da0073e9SAndroid Build Coastguard Worker }
3095*da0073e9SAndroid Build Coastguard Worker
3096*da0073e9SAndroid Build Coastguard Worker void record_trace(
3097*da0073e9SAndroid Build Coastguard Worker TraceEntry::Action action,
3098*da0073e9SAndroid Build Coastguard Worker size_t addr,
3099*da0073e9SAndroid Build Coastguard Worker size_t size,
3100*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream,
3101*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
3102*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<GatheredContext> context) {
3103*da0073e9SAndroid Build Coastguard Worker if (!record_history && trace_trackers_.empty())
3104*da0073e9SAndroid Build Coastguard Worker return;
3105*da0073e9SAndroid Build Coastguard Worker
3106*da0073e9SAndroid Build Coastguard Worker auto te = TraceEntry(
3107*da0073e9SAndroid Build Coastguard Worker action,
3108*da0073e9SAndroid Build Coastguard Worker device,
3109*da0073e9SAndroid Build Coastguard Worker addr,
3110*da0073e9SAndroid Build Coastguard Worker size,
3111*da0073e9SAndroid Build Coastguard Worker stream,
3112*da0073e9SAndroid Build Coastguard Worker getApproximateTime(),
3113*da0073e9SAndroid Build Coastguard Worker record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr);
3114*da0073e9SAndroid Build Coastguard Worker
3115*da0073e9SAndroid Build Coastguard Worker // Callbacks should not include any PyTorch call
3116*da0073e9SAndroid Build Coastguard Worker for (const auto& cb : trace_trackers_) {
3117*da0073e9SAndroid Build Coastguard Worker cb(te);
3118*da0073e9SAndroid Build Coastguard Worker }
3119*da0073e9SAndroid Build Coastguard Worker
3120*da0073e9SAndroid Build Coastguard Worker if (record_history) {
3121*da0073e9SAndroid Build Coastguard Worker alloc_buffer.insertEntries(te);
3122*da0073e9SAndroid Build Coastguard Worker }
3123*da0073e9SAndroid Build Coastguard Worker }
3124*da0073e9SAndroid Build Coastguard Worker };
3125*da0073e9SAndroid Build Coastguard Worker
3126*da0073e9SAndroid Build Coastguard Worker // Returns whether to force all allocations to bypass the caching allocator and
3127*da0073e9SAndroid Build Coastguard Worker // go straight to cudaMalloc. This setting is useful when debugging GPU memory
3128*da0073e9SAndroid Build Coastguard Worker // errors, since the caching allocator foils cuda-memcheck.
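// A minimal usage sketch (the script name is illustrative, not part of this
// file): disable caching so cuda-memcheck / compute-sanitizer observe the raw
// cudaMalloc/cudaFree calls, e.g.
//   PYTORCH_NO_CUDA_MEMORY_CACHING=1 compute-sanitizer python train.py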
3129*da0073e9SAndroid Build Coastguard Worker bool forceUncachedAllocator() {
3130*da0073e9SAndroid Build Coastguard Worker static bool force_uncached =
3131*da0073e9SAndroid Build Coastguard Worker getenv("PYTORCH_NO_CUDA_MEMORY_CACHING") != nullptr;
3132*da0073e9SAndroid Build Coastguard Worker return force_uncached;
3133*da0073e9SAndroid Build Coastguard Worker }
3134*da0073e9SAndroid Build Coastguard Worker
3135*da0073e9SAndroid Build Coastguard Worker static void uncached_delete(void* ptr) {
3136*da0073e9SAndroid Build Coastguard Worker if (TORCH_SDT_IS_ENABLED(free)) {
3137*da0073e9SAndroid Build Coastguard Worker TORCH_SDT_WITH_SEMAPHORE(free, ptr);
3138*da0073e9SAndroid Build Coastguard Worker }
3139*da0073e9SAndroid Build Coastguard Worker
3140*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
3141*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) {
3142*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_memory_deallocation(
3143*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, reinterpret_cast<uintptr_t>(ptr));
3144*da0073e9SAndroid Build Coastguard Worker }
3145*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaFree(ptr));
3146*da0073e9SAndroid Build Coastguard Worker }
3147*da0073e9SAndroid Build Coastguard Worker
3148*da0073e9SAndroid Build Coastguard Worker void local_raw_delete(void* ptr);
3149*da0073e9SAndroid Build Coastguard Worker
3150*da0073e9SAndroid Build Coastguard Worker class NativeCachingAllocator : public CUDAAllocator {
3151*da0073e9SAndroid Build Coastguard Worker private:
3152*da0073e9SAndroid Build Coastguard Worker // Shard allocation region to have independent mutexes to reduce contention.
3153*da0073e9SAndroid Build Coastguard Worker static constexpr size_t kNumMutexShard = 67;
3154*da0073e9SAndroid Build Coastguard Worker
3155*da0073e9SAndroid Build Coastguard Worker // TODO: use std::hardware_destructive_interference_size once available
3156*da0073e9SAndroid Build Coastguard Worker struct alignas(64) AlignedMutex {
3157*da0073e9SAndroid Build Coastguard Worker std::mutex m;
3158*da0073e9SAndroid Build Coastguard Worker };
3159*da0073e9SAndroid Build Coastguard Worker
3160*da0073e9SAndroid Build Coastguard Worker std::array<AlignedMutex, kNumMutexShard> mutex;
3161*da0073e9SAndroid Build Coastguard Worker
3162*da0073e9SAndroid Build Coastguard Worker // allocated blocks by device pointer
3163*da0073e9SAndroid Build Coastguard Worker std::array<ska::flat_hash_map<void*, Block*>, kNumMutexShard>
3164*da0073e9SAndroid Build Coastguard Worker allocated_blocks;
3165*da0073e9SAndroid Build Coastguard Worker
3166*da0073e9SAndroid Build Coastguard Worker static size_t get_mutex_shard_id(void* ptr) {
3167*da0073e9SAndroid Build Coastguard Worker return twang_mix64((size_t)ptr) % kNumMutexShard;
3168*da0073e9SAndroid Build Coastguard Worker }
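  // Both add_allocated_block() and get_allocated_block() hash the pointer with
  // twang_mix64 and take it modulo kNumMutexShard, so lookups for different
  // pointers usually contend on different mutexes and different maps.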
3169*da0073e9SAndroid Build Coastguard Worker
3170*da0073e9SAndroid Build Coastguard Worker void add_allocated_block(Block* block) {
3171*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
3172*da0073e9SAndroid Build Coastguard Worker const auto mutex_shard_id = get_mutex_shard_id(block->ptr);
3173*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock(mutex[mutex_shard_id].m);
3174*da0073e9SAndroid Build Coastguard Worker allocated_blocks[mutex_shard_id][block->ptr] = block;
3175*da0073e9SAndroid Build Coastguard Worker }
3176*da0073e9SAndroid Build Coastguard Worker
3177*da0073e9SAndroid Build Coastguard Worker // Variables used by memory snapshots
3178*da0073e9SAndroid Build Coastguard Worker c10::ApproximateClockToUnixTimeConverter clock_converter;
3179*da0073e9SAndroid Build Coastguard Worker bool record_history = false;
3180*da0073e9SAndroid Build Coastguard Worker RingBuffer<AnnotationEntry> annotation_buffer;
3181*da0073e9SAndroid Build Coastguard Worker
3182*da0073e9SAndroid Build Coastguard Worker public:
3183*da0073e9SAndroid Build Coastguard Worker std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocator;
3184*da0073e9SAndroid Build Coastguard Worker
3185*da0073e9SAndroid Build Coastguard Worker Block* get_allocated_block(void* ptr, bool remove = false) {
3186*da0073e9SAndroid Build Coastguard Worker const auto mutex_shard_id = get_mutex_shard_id(ptr);
3187*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock(mutex[mutex_shard_id].m);
3188*da0073e9SAndroid Build Coastguard Worker auto it = allocated_blocks[mutex_shard_id].find(ptr);
3189*da0073e9SAndroid Build Coastguard Worker if (it == allocated_blocks[mutex_shard_id].end()) {
3190*da0073e9SAndroid Build Coastguard Worker return nullptr;
3191*da0073e9SAndroid Build Coastguard Worker }
3192*da0073e9SAndroid Build Coastguard Worker Block* block = it->second;
3193*da0073e9SAndroid Build Coastguard Worker if (remove) {
3194*da0073e9SAndroid Build Coastguard Worker allocated_blocks[mutex_shard_id].erase(it);
3195*da0073e9SAndroid Build Coastguard Worker }
3196*da0073e9SAndroid Build Coastguard Worker return block;
3197*da0073e9SAndroid Build Coastguard Worker }
3198*da0073e9SAndroid Build Coastguard Worker
3199*da0073e9SAndroid Build Coastguard Worker void init(int device_count) override {
3200*da0073e9SAndroid Build Coastguard Worker const auto size = static_cast<int64_t>(device_allocator.size());
3201*da0073e9SAndroid Build Coastguard Worker if (size < device_count) {
3202*da0073e9SAndroid Build Coastguard Worker device_allocator.resize(device_count);
3203*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(size, device_count)) {
3204*da0073e9SAndroid Build Coastguard Worker device_allocator[i] = std::make_unique<DeviceCachingAllocator>();
3205*da0073e9SAndroid Build Coastguard Worker }
3206*da0073e9SAndroid Build Coastguard Worker }
3207*da0073e9SAndroid Build Coastguard Worker }
3208*da0073e9SAndroid Build Coastguard Worker
3209*da0073e9SAndroid Build Coastguard Worker bool initialized() override {
3210*da0073e9SAndroid Build Coastguard Worker return !device_allocator.empty();
3211*da0073e9SAndroid Build Coastguard Worker }
3212*da0073e9SAndroid Build Coastguard Worker
3213*da0073e9SAndroid Build Coastguard Worker /** allocates a block which is safe to use from the provided stream */
3214*da0073e9SAndroid Build Coastguard Worker void malloc(
3215*da0073e9SAndroid Build Coastguard Worker void** devPtr,
3216*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
3217*da0073e9SAndroid Build Coastguard Worker size_t size,
3218*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream) {
3219*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
3220*da0073e9SAndroid Build Coastguard Worker 0 <= device && static_cast<size_t>(device) < device_allocator.size(),
3221*da0073e9SAndroid Build Coastguard Worker "Allocator not initialized for device ",
3222*da0073e9SAndroid Build Coastguard Worker device,
3223*da0073e9SAndroid Build Coastguard Worker ": did you call init?");
3224*da0073e9SAndroid Build Coastguard Worker Block* block = device_allocator[device]->malloc(device, size, stream);
3225*da0073e9SAndroid Build Coastguard Worker add_allocated_block(block);
3226*da0073e9SAndroid Build Coastguard Worker *devPtr = (void*)block->ptr;
3227*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
3228*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) {
3229*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_memory_allocation(
3230*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, reinterpret_cast<uintptr_t>(*devPtr));
3231*da0073e9SAndroid Build Coastguard Worker }
3232*da0073e9SAndroid Build Coastguard Worker }
3233*da0073e9SAndroid Build Coastguard Worker
3234*da0073e9SAndroid Build Coastguard Worker void free(void* ptr) {
3235*da0073e9SAndroid Build Coastguard Worker if (!ptr) {
3236*da0073e9SAndroid Build Coastguard Worker return;
3237*da0073e9SAndroid Build Coastguard Worker }
3238*da0073e9SAndroid Build Coastguard Worker Block* block = get_allocated_block(ptr, true /* remove */);
3239*da0073e9SAndroid Build Coastguard Worker if (!block) {
3240*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "invalid device pointer: ", ptr);
3241*da0073e9SAndroid Build Coastguard Worker }
3242*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
3243*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) {
3244*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_memory_deallocation(
3245*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, reinterpret_cast<uintptr_t>(block->ptr));
3246*da0073e9SAndroid Build Coastguard Worker }
3247*da0073e9SAndroid Build Coastguard Worker device_allocator[block->device]->free(block);
3248*da0073e9SAndroid Build Coastguard Worker }
3249*da0073e9SAndroid Build Coastguard Worker
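  // Typically reached from Python via
  // torch.cuda.set_per_process_memory_fraction(fraction, device) (assumed
  // binding); the fraction caps this process's cached allocations on the
  // device at roughly fraction * total device memory.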
3250*da0073e9SAndroid Build Coastguard Worker void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
3251*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
3252*da0073e9SAndroid Build Coastguard Worker 0 <= device && static_cast<size_t>(device) < device_allocator.size(),
3253*da0073e9SAndroid Build Coastguard Worker "Allocator not initialized for device ",
3254*da0073e9SAndroid Build Coastguard Worker device,
3255*da0073e9SAndroid Build Coastguard Worker ": did you call init?");
3256*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
3257*da0073e9SAndroid Build Coastguard Worker 0 <= fraction && fraction <= 1,
3258*da0073e9SAndroid Build Coastguard Worker "invalid fraction:",
3259*da0073e9SAndroid Build Coastguard Worker fraction,
3260*da0073e9SAndroid Build Coastguard Worker ". Please set within (0, 1).");
3261*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::SetDevice(device));
3262*da0073e9SAndroid Build Coastguard Worker device_allocator[device]->setMemoryFraction(fraction);
3263*da0073e9SAndroid Build Coastguard Worker }
3264*da0073e9SAndroid Build Coastguard Worker
3265*da0073e9SAndroid Build Coastguard Worker void recordHistory(
3266*da0073e9SAndroid Build Coastguard Worker bool enabled,
3267*da0073e9SAndroid Build Coastguard Worker CreateContextFn context_recorder,
3268*da0073e9SAndroid Build Coastguard Worker size_t alloc_buffer_max_entries,
3269*da0073e9SAndroid Build Coastguard Worker RecordContext when) override {
3270*da0073e9SAndroid Build Coastguard Worker record_history = enabled;
3271*da0073e9SAndroid Build Coastguard Worker annotation_buffer.setMaxEntries(alloc_buffer_max_entries);
3272*da0073e9SAndroid Build Coastguard Worker annotation_buffer.clear();
3273*da0073e9SAndroid Build Coastguard Worker for (auto& allocator : device_allocator) {
3274*da0073e9SAndroid Build Coastguard Worker allocator->recordHistory(
3275*da0073e9SAndroid Build Coastguard Worker enabled, context_recorder, alloc_buffer_max_entries, when);
3276*da0073e9SAndroid Build Coastguard Worker }
3277*da0073e9SAndroid Build Coastguard Worker }
3278*da0073e9SAndroid Build Coastguard Worker
3279*da0073e9SAndroid Build Coastguard Worker void recordAnnotation(
3280*da0073e9SAndroid Build Coastguard Worker const std::vector<std::pair<std::string, std::string>>& md) override {
3281*da0073e9SAndroid Build Coastguard Worker if (!record_history) {
3282*da0073e9SAndroid Build Coastguard Worker return;
3283*da0073e9SAndroid Build Coastguard Worker }
3284*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device = 0;
3285*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
3286*da0073e9SAndroid Build Coastguard Worker auto ae = AnnotationEntry(
3287*da0073e9SAndroid Build Coastguard Worker /*device=*/device,
3288*da0073e9SAndroid Build Coastguard Worker /*time=*/getApproximateTime());
3289*da0073e9SAndroid Build Coastguard Worker for (const auto& md_pair : md) {
3290*da0073e9SAndroid Build Coastguard Worker ae.recordUserMetadata(md_pair.first, md_pair.second);
3291*da0073e9SAndroid Build Coastguard Worker }
3292*da0073e9SAndroid Build Coastguard Worker annotation_buffer.insertEntries(ae);
3293*da0073e9SAndroid Build Coastguard Worker }
3294*da0073e9SAndroid Build Coastguard Worker
3295*da0073e9SAndroid Build Coastguard Worker bool isHistoryEnabled() override {
3296*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device = 0;
3297*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
3298*da0073e9SAndroid Build Coastguard Worker return device_allocator[device]->isHistoryEnabled();
3299*da0073e9SAndroid Build Coastguard Worker }
3300*da0073e9SAndroid Build Coastguard Worker
3301*da0073e9SAndroid Build Coastguard Worker bool checkPoolLiveAllocations(
3302*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
3303*da0073e9SAndroid Build Coastguard Worker MempoolId_t mempool_id,
3304*da0073e9SAndroid Build Coastguard Worker const std::unordered_set<void*>& expected_live_allocations) override {
3305*da0073e9SAndroid Build Coastguard Worker return device_allocator[device]->checkPoolLiveAllocations(
3306*da0073e9SAndroid Build Coastguard Worker mempool_id, expected_live_allocations);
3307*da0073e9SAndroid Build Coastguard Worker }
3308*da0073e9SAndroid Build Coastguard Worker
3309*da0073e9SAndroid Build Coastguard Worker void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override {
3310*da0073e9SAndroid Build Coastguard Worker for (auto& allocator : device_allocator) {
3311*da0073e9SAndroid Build Coastguard Worker allocator->attachOutOfMemoryObserver(observer);
3312*da0073e9SAndroid Build Coastguard Worker }
3313*da0073e9SAndroid Build Coastguard Worker }
3314*da0073e9SAndroid Build Coastguard Worker
3315*da0073e9SAndroid Build Coastguard Worker void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) override {
3316*da0073e9SAndroid Build Coastguard Worker for (auto& allocator : device_allocator) {
3317*da0073e9SAndroid Build Coastguard Worker allocator->attachAllocatorTraceTracker(tracker);
3318*da0073e9SAndroid Build Coastguard Worker }
3319*da0073e9SAndroid Build Coastguard Worker }
3320*da0073e9SAndroid Build Coastguard Worker
3321*da0073e9SAndroid Build Coastguard Worker void emptyCache() override {
3322*da0073e9SAndroid Build Coastguard Worker for (auto& da : device_allocator)
3323*da0073e9SAndroid Build Coastguard Worker da->emptyCache();
3324*da0073e9SAndroid Build Coastguard Worker }
3325*da0073e9SAndroid Build Coastguard Worker
3326*da0073e9SAndroid Build Coastguard Worker void* getBaseAllocation(void* ptr, size_t* outSize) override {
3327*da0073e9SAndroid Build Coastguard Worker Block* block = get_allocated_block(ptr);
3328*da0073e9SAndroid Build Coastguard Worker if (!block) {
3329*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "invalid device pointer: ", ptr);
3330*da0073e9SAndroid Build Coastguard Worker }
3331*da0073e9SAndroid Build Coastguard Worker return device_allocator[block->device]->getBaseAllocation(block, outSize);
3332*da0073e9SAndroid Build Coastguard Worker }
3333*da0073e9SAndroid Build Coastguard Worker
3334*da0073e9SAndroid Build Coastguard Worker ShareableHandle shareIpcHandle(void* ptr) override {
3335*da0073e9SAndroid Build Coastguard Worker Block* block = get_allocated_block(ptr);
3336*da0073e9SAndroid Build Coastguard Worker if (!block) {
3337*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "invalid device pointer: ", ptr);
3338*da0073e9SAndroid Build Coastguard Worker }
3339*da0073e9SAndroid Build Coastguard Worker return device_allocator[block->device]->shareIpcHandle(block);
3340*da0073e9SAndroid Build Coastguard Worker }
3341*da0073e9SAndroid Build Coastguard Worker
3342*da0073e9SAndroid Build Coastguard Worker void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) override {
3343*da0073e9SAndroid Build Coastguard Worker // An empty tensor's storage().data() might be a null ptr. As there are no
3344*da0073e9SAndroid Build Coastguard Worker // blocks associated with those tensors, it is fine to do nothing here.
3345*da0073e9SAndroid Build Coastguard Worker if (!ptr.get()) {
3346*da0073e9SAndroid Build Coastguard Worker return;
3347*da0073e9SAndroid Build Coastguard Worker }
3348*da0073e9SAndroid Build Coastguard Worker
3349*da0073e9SAndroid Build Coastguard Worker // If a tensor was not allocated by this instance, simply skip it.
3350*da0073e9SAndroid Build Coastguard Worker // This usually happens when CUDA tensors are shared across processes;
3351*da0073e9SAndroid Build Coastguard Worker // we have implemented a reference-counting-based sharing mechanism to
3352*da0073e9SAndroid Build Coastguard Worker // guarantee tensors won't be accidentally freed by one process while
3353*da0073e9SAndroid Build Coastguard Worker // they are still being used in another.
3354*da0073e9SAndroid Build Coastguard Worker if (ptr.get_deleter() != &local_raw_delete)
3355*da0073e9SAndroid Build Coastguard Worker return;
3356*da0073e9SAndroid Build Coastguard Worker
3357*da0073e9SAndroid Build Coastguard Worker Block* block = get_allocated_block(ptr.get());
3358*da0073e9SAndroid Build Coastguard Worker // block must not be null when we reach this point
3359*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(block != nullptr, "No allocated block can be found");
3360*da0073e9SAndroid Build Coastguard Worker device_allocator[block->device]->recordStream(block, stream);
3361*da0073e9SAndroid Build Coastguard Worker }
3362*da0073e9SAndroid Build Coastguard Worker
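  // On the Python side this data is typically surfaced (assumed entry points)
  // by torch.cuda.memory._record_memory_history() followed by
  // torch.cuda.memory._snapshot(), which exposes the segments, per-device
  // traces, and allocator config metadata gathered below.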
3363*da0073e9SAndroid Build Coastguard Worker SnapshotInfo snapshot() override {
3364*da0073e9SAndroid Build Coastguard Worker // Set-up converter to convert timestamps from tsc to microseconds.
3365*da0073e9SAndroid Build Coastguard Worker auto tsc_to_ns = clock_converter.makeConverter();
3366*da0073e9SAndroid Build Coastguard Worker auto tsc_to_us = [=](approx_time_t t_approx) {
3367*da0073e9SAndroid Build Coastguard Worker return tsc_to_ns(t_approx) / 1000;
3368*da0073e9SAndroid Build Coastguard Worker };
3369*da0073e9SAndroid Build Coastguard Worker
3370*da0073e9SAndroid Build Coastguard Worker SnapshotInfo result;
3371*da0073e9SAndroid Build Coastguard Worker
3372*da0073e9SAndroid Build Coastguard Worker // Get AnnotationEntry list and convert the timestamps.
3373*da0073e9SAndroid Build Coastguard Worker annotation_buffer.getEntries(result.external_annotations);
3374*da0073e9SAndroid Build Coastguard Worker for (auto& ae : result.external_annotations) {
3375*da0073e9SAndroid Build Coastguard Worker ae.time_.t_ = tsc_to_us(ae.time_.approx_t_);
3376*da0073e9SAndroid Build Coastguard Worker }
3377*da0073e9SAndroid Build Coastguard Worker
3378*da0073e9SAndroid Build Coastguard Worker // Get the device_traces' TraceEntry lists.
3379*da0073e9SAndroid Build Coastguard Worker for (auto& da : device_allocator) {
3380*da0073e9SAndroid Build Coastguard Worker result.device_traces.emplace_back(da->trace(tsc_to_us));
3381*da0073e9SAndroid Build Coastguard Worker auto snap = da->snapshot();
3382*da0073e9SAndroid Build Coastguard Worker result.segments.insert(result.segments.end(), snap.begin(), snap.end());
3383*da0073e9SAndroid Build Coastguard Worker }
3384*da0073e9SAndroid Build Coastguard Worker
3385*da0073e9SAndroid Build Coastguard Worker auto& md = result.config_metadata;
3386*da0073e9SAndroid Build Coastguard Worker md.garbage_collection_threshold =
3387*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig::garbage_collection_threshold();
3388*da0073e9SAndroid Build Coastguard Worker md.max_split_size = CUDAAllocatorConfig::max_split_size();
3389*da0073e9SAndroid Build Coastguard Worker md.pinned_num_register_threads =
3390*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig::pinned_num_register_threads();
3391*da0073e9SAndroid Build Coastguard Worker md.expandable_segments = CUDAAllocatorConfig::expandable_segments();
3392*da0073e9SAndroid Build Coastguard Worker md.release_lock_on_malloc =
3393*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig::release_lock_on_cudamalloc();
3394*da0073e9SAndroid Build Coastguard Worker md.pinned_use_host_register =
3395*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig::pinned_use_cuda_host_register();
3396*da0073e9SAndroid Build Coastguard Worker md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings();
3397*da0073e9SAndroid Build Coastguard Worker md.roundup_power2_divisions =
3398*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig::roundup_power2_divisions();
3399*da0073e9SAndroid Build Coastguard Worker
3400*da0073e9SAndroid Build Coastguard Worker return result;
3401*da0073e9SAndroid Build Coastguard Worker }
3402*da0073e9SAndroid Build Coastguard Worker
3403*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<AllocatorState> getCheckpointState(
3404*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
3405*da0073e9SAndroid Build Coastguard Worker MempoolId_t id) override {
3406*da0073e9SAndroid Build Coastguard Worker return device_allocator[device]->getCheckpointState(id);
3407*da0073e9SAndroid Build Coastguard Worker }
3408*da0073e9SAndroid Build Coastguard Worker
3409*da0073e9SAndroid Build Coastguard Worker /**
3410*da0073e9SAndroid Build Coastguard Worker * @brief Checkpoint the private pool state identified in `as` to its prior
3411*da0073e9SAndroid Build Coastguard Worker * state
3412*da0073e9SAndroid Build Coastguard Worker *
3413*da0073e9SAndroid Build Coastguard Worker * @param device - device of the pool to manipulate
3414*da0073e9SAndroid Build Coastguard Worker * @param as - allocator state
3415*da0073e9SAndroid Build Coastguard Worker * @param stale_live_storages - storages of tensors which are currently
3416*da0073e9SAndroid Build Coastguard Worker * allocated but which will not be allocated after the checkpoint is set.
3417*da0073e9SAndroid Build Coastguard Worker * For these storages we will remove their deleter function.
3418*da0073e9SAndroid Build Coastguard Worker * @return CheckpointDelta - Freed Pointers and DataPtrs that contain deleter
3419*da0073e9SAndroid Build Coastguard Worker * functions for all allocated blocks in the new checkpoint state.
3420*da0073e9SAndroid Build Coastguard Worker */
3421*da0073e9SAndroid Build Coastguard Worker CheckpointDelta setCheckpointPoolState(
3422*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
3423*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<AllocatorState> as) override {
3424*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<PrivatePoolState> pps =
3425*da0073e9SAndroid Build Coastguard Worker std::dynamic_pointer_cast<PrivatePoolState>(as);
3426*da0073e9SAndroid Build Coastguard Worker
3427*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(pps, "Expected PrivatePoolState");
3428*da0073e9SAndroid Build Coastguard Worker
3429*da0073e9SAndroid Build Coastguard Worker auto rr = device_allocator[device]->setCheckpointPoolState(*pps);
3430*da0073e9SAndroid Build Coastguard Worker
3431*da0073e9SAndroid Build Coastguard Worker CheckpointDelta cpd;
3432*da0073e9SAndroid Build Coastguard Worker for (void* ptr : rr.allocations_freed) {
3433*da0073e9SAndroid Build Coastguard Worker get_allocated_block(ptr, /*remove*/ true);
3434*da0073e9SAndroid Build Coastguard Worker cpd.ptrs_freed.push_back(ptr);
3435*da0073e9SAndroid Build Coastguard Worker }
3436*da0073e9SAndroid Build Coastguard Worker for (Block* block : rr.allocations_created) {
3437*da0073e9SAndroid Build Coastguard Worker add_allocated_block(block);
3438*da0073e9SAndroid Build Coastguard Worker cpd.dataptrs_allocd.emplace_back(
3439*da0073e9SAndroid Build Coastguard Worker block->ptr,
3440*da0073e9SAndroid Build Coastguard Worker block->ptr,
3441*da0073e9SAndroid Build Coastguard Worker &local_raw_delete,
3442*da0073e9SAndroid Build Coastguard Worker Device(DeviceType::CUDA, device));
3443*da0073e9SAndroid Build Coastguard Worker }
3444*da0073e9SAndroid Build Coastguard Worker
3445*da0073e9SAndroid Build Coastguard Worker return cpd;
3446*da0073e9SAndroid Build Coastguard Worker }
3447*da0073e9SAndroid Build Coastguard Worker
3448*da0073e9SAndroid Build Coastguard Worker DataPtr allocate(size_t size) override {
3449*da0073e9SAndroid Build Coastguard Worker constexpr size_t one_exa_bytes = 1152921504606846976ULL;
3450*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK_WITH(
3451*da0073e9SAndroid Build Coastguard Worker OutOfMemoryError,
3452*da0073e9SAndroid Build Coastguard Worker size < one_exa_bytes,
3453*da0073e9SAndroid Build Coastguard Worker "CUDA out of memory. Tried to allocate more than 1EB memory.");
3454*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device = 0;
3455*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
3456*da0073e9SAndroid Build Coastguard Worker void* devPtr = nullptr;
3457*da0073e9SAndroid Build Coastguard Worker void (*deleteFunc)(void*) = &local_raw_delete;
3458*da0073e9SAndroid Build Coastguard Worker CUDAStream stream = cuda::getCurrentCUDAStream(device);
3459*da0073e9SAndroid Build Coastguard Worker
3460*da0073e9SAndroid Build Coastguard Worker if (forceUncachedAllocator()) {
3461*da0073e9SAndroid Build Coastguard Worker deleteFunc = &uncached_delete;
3462*da0073e9SAndroid Build Coastguard Worker
3463*da0073e9SAndroid Build Coastguard Worker // Deliberately don't use cudaMallocMaybeCapturing here, to force an error
3464*da0073e9SAndroid Build Coastguard Worker // if someone tries to use forceUncachedAllocator while capturing.
3465*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaMalloc(&devPtr, size));
3466*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
3467*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) {
3468*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_memory_allocation(
3469*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, reinterpret_cast<uintptr_t>(devPtr));
3470*da0073e9SAndroid Build Coastguard Worker }
3471*da0073e9SAndroid Build Coastguard Worker } else {
3472*da0073e9SAndroid Build Coastguard Worker if (size != 0) {
3473*da0073e9SAndroid Build Coastguard Worker this->malloc(&devPtr, device, size, stream);
3474*da0073e9SAndroid Build Coastguard Worker }
3475*da0073e9SAndroid Build Coastguard Worker }
3476*da0073e9SAndroid Build Coastguard Worker
3477*da0073e9SAndroid Build Coastguard Worker if (size && TORCH_SDT_IS_ENABLED(malloc)) {
3478*da0073e9SAndroid Build Coastguard Worker TORCH_SDT_WITH_SEMAPHORE(malloc, devPtr, device, size, stream.id());
3479*da0073e9SAndroid Build Coastguard Worker }
3480*da0073e9SAndroid Build Coastguard Worker
3481*da0073e9SAndroid Build Coastguard Worker return {devPtr, devPtr, deleteFunc, Device(DeviceType::CUDA, device)};
3482*da0073e9SAndroid Build Coastguard Worker }
3483*da0073e9SAndroid Build Coastguard Worker DeleterFnPtr raw_deleter() const override {
3484*da0073e9SAndroid Build Coastguard Worker if (forceUncachedAllocator()) {
3485*da0073e9SAndroid Build Coastguard Worker return &uncached_delete;
3486*da0073e9SAndroid Build Coastguard Worker } else {
3487*da0073e9SAndroid Build Coastguard Worker return &local_raw_delete;
3488*da0073e9SAndroid Build Coastguard Worker }
3489*da0073e9SAndroid Build Coastguard Worker }
3490*da0073e9SAndroid Build Coastguard Worker void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override {
3491*da0073e9SAndroid Build Coastguard Worker device_allocator[device]->cacheInfo(largestBlock);
3492*da0073e9SAndroid Build Coastguard Worker }
3493*da0073e9SAndroid Build Coastguard Worker void assertValidDevice(c10::DeviceIndex device) {
3494*da0073e9SAndroid Build Coastguard Worker const auto device_num = device_allocator.size();
3495*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
3496*da0073e9SAndroid Build Coastguard Worker 0 <= device && device < static_cast<int64_t>(device_num),
3497*da0073e9SAndroid Build Coastguard Worker "Invalid device argument ",
3498*da0073e9SAndroid Build Coastguard Worker device,
3499*da0073e9SAndroid Build Coastguard Worker ": did you call init?");
3500*da0073e9SAndroid Build Coastguard Worker }
3501*da0073e9SAndroid Build Coastguard Worker
3502*da0073e9SAndroid Build Coastguard Worker DeviceStats getDeviceStats(c10::DeviceIndex device) override {
3503*da0073e9SAndroid Build Coastguard Worker assertValidDevice(device);
3504*da0073e9SAndroid Build Coastguard Worker return device_allocator[device]->getStats();
3505*da0073e9SAndroid Build Coastguard Worker }
3506*da0073e9SAndroid Build Coastguard Worker
3507*da0073e9SAndroid Build Coastguard Worker void resetAccumulatedStats(c10::DeviceIndex device) override {
3508*da0073e9SAndroid Build Coastguard Worker assertValidDevice(device);
3509*da0073e9SAndroid Build Coastguard Worker device_allocator[device]->resetAccumulatedStats();
3510*da0073e9SAndroid Build Coastguard Worker }
3511*da0073e9SAndroid Build Coastguard Worker
3512*da0073e9SAndroid Build Coastguard Worker void resetPeakStats(c10::DeviceIndex device) override {
3513*da0073e9SAndroid Build Coastguard Worker assertValidDevice(device);
3514*da0073e9SAndroid Build Coastguard Worker device_allocator[device]->resetPeakStats();
3515*da0073e9SAndroid Build Coastguard Worker }
3516*da0073e9SAndroid Build Coastguard Worker // CUDAGraph interactions
3517*da0073e9SAndroid Build Coastguard Worker void beginAllocateToPool(
3518*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
3519*da0073e9SAndroid Build Coastguard Worker MempoolId_t mempool_id,
3520*da0073e9SAndroid Build Coastguard Worker std::function<bool(cudaStream_t)> filter) override {
3521*da0073e9SAndroid Build Coastguard Worker assertValidDevice(device);
3522*da0073e9SAndroid Build Coastguard Worker device_allocator[device]->beginAllocateToPool(
3523*da0073e9SAndroid Build Coastguard Worker std::move(mempool_id), std::move(filter));
3524*da0073e9SAndroid Build Coastguard Worker }
3525*da0073e9SAndroid Build Coastguard Worker
3526*da0073e9SAndroid Build Coastguard Worker void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id)
3527*da0073e9SAndroid Build Coastguard Worker override {
3528*da0073e9SAndroid Build Coastguard Worker assertValidDevice(device);
3529*da0073e9SAndroid Build Coastguard Worker device_allocator[device]->endAllocateToPool(mempool_id);
3530*da0073e9SAndroid Build Coastguard Worker }
3531*da0073e9SAndroid Build Coastguard Worker
3532*da0073e9SAndroid Build Coastguard Worker void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
3533*da0073e9SAndroid Build Coastguard Worker assertValidDevice(device);
3534*da0073e9SAndroid Build Coastguard Worker device_allocator[device]->releasePool(std::move(mempool_id));
3535*da0073e9SAndroid Build Coastguard Worker }
3536*da0073e9SAndroid Build Coastguard Worker
3537*da0073e9SAndroid Build Coastguard Worker void* raw_alloc(size_t nbytes) override {
3538*da0073e9SAndroid Build Coastguard Worker if (nbytes == 0) {
3539*da0073e9SAndroid Build Coastguard Worker return nullptr;
3540*da0073e9SAndroid Build Coastguard Worker }
3541*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device = 0;
3542*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
3543*da0073e9SAndroid Build Coastguard Worker void* r = nullptr;
3544*da0073e9SAndroid Build Coastguard Worker malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
3545*da0073e9SAndroid Build Coastguard Worker return r;
3546*da0073e9SAndroid Build Coastguard Worker }
3547*da0073e9SAndroid Build Coastguard Worker
3548*da0073e9SAndroid Build Coastguard Worker void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override {
3549*da0073e9SAndroid Build Coastguard Worker if (nbytes == 0) {
3550*da0073e9SAndroid Build Coastguard Worker return nullptr;
3551*da0073e9SAndroid Build Coastguard Worker }
3552*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device = 0;
3553*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
3554*da0073e9SAndroid Build Coastguard Worker void* r = nullptr;
3555*da0073e9SAndroid Build Coastguard Worker malloc(&r, device, nbytes, stream);
3556*da0073e9SAndroid Build Coastguard Worker return r;
3557*da0073e9SAndroid Build Coastguard Worker }
3558*da0073e9SAndroid Build Coastguard Worker
3559*da0073e9SAndroid Build Coastguard Worker void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access)
3560*da0073e9SAndroid Build Coastguard Worker override {
3561*da0073e9SAndroid Build Coastguard Worker c10::cuda::CUDAGuard device_guard(dev);
3562*da0073e9SAndroid Build Coastguard Worker cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
3563*da0073e9SAndroid Build Coastguard Worker if (err == cudaErrorPeerAccessAlreadyEnabled) {
3564*da0073e9SAndroid Build Coastguard Worker // ignore and clear the error if access was already enabled
3565*da0073e9SAndroid Build Coastguard Worker (void)cudaGetLastError();
3566*da0073e9SAndroid Build Coastguard Worker } else {
3567*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(err);
3568*da0073e9SAndroid Build Coastguard Worker }
3569*da0073e9SAndroid Build Coastguard Worker device_allocator[dev_to_access]->addPeerAccess(dev);
3570*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock(IpcMutex);
3571*da0073e9SAndroid Build Coastguard Worker for (auto& entry : ipcMemHandle_to_devptr) {
3572*da0073e9SAndroid Build Coastguard Worker if (entry.second.device_ == dev_to_access &&
3573*da0073e9SAndroid Build Coastguard Worker entry.second.expandable_segment_) {
3574*da0073e9SAndroid Build Coastguard Worker entry.second.expandable_segment_->addPeer(dev);
3575*da0073e9SAndroid Build Coastguard Worker }
3576*da0073e9SAndroid Build Coastguard Worker }
3577*da0073e9SAndroid Build Coastguard Worker }

  cudaError_t memcpyAsync(
      void* dst,
      int dstDevice,
      const void* src,
      int srcDevice,
      size_t count,
      cudaStream_t stream,
      bool p2p_enabled) override {
    if (p2p_enabled || // memcpy ok because memory is mapped in both devices
        srcDevice == dstDevice || // memcpy ok on a single device
        // memcpy ok because both dst and src must have come from cudaMalloc
        (!device_allocator[dstDevice]->hasAllocatedExpandableSegments() &&
         !device_allocator[srcDevice]->hasAllocatedExpandableSegments())) {
      return cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream);
    }
    // when p2p is not enabled, only cudaMemcpyPeerAsync correctly handles
    // memory not allocated via cudaMalloc
    return cudaMemcpyPeerAsync(dst, dstDevice, src, srcDevice, count, stream);
  }

  void raw_delete(void* ptr) override {
    this->free(ptr);
  }

  // In CUDA IPC, the sending process shares a tensor via shareIpcHandle, and
  // getIpcDevPtr is called by the receiving process to map the CUDA memory
  // from the sending process into its own address space.

  // When the memory was allocated with cudaMalloc we use the
  // cudaIpcMemHandle_t APIs. These APIs only allow sharing the whole memory
  // block associated with a cudaIpcMemHandle_t, and the handle can be opened
  // only **once** per context per process. Multiple storages can live in the
  // same IPC memory block, so we must cache the device pointer and construct
  // typed storages from it as they come.

  // When the memory was allocated with cuMemCreate (expandable segments), we
  // use cuMemExportToShareableHandle to create a file descriptor that can be
  // sent to the other process to share the object. We then recreate the part
  // of the expandable segment necessary to load the allocation.

  // ipcMemHandle_to_devptr caches the mapping from a shareable handle to this
  // process' memory-mapping information for that handle, so we never create
  // the mapping twice. When the shared_ptr is no longer in use we clean up
  // the cache.
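
  // A minimal sketch of the serialized handle layout this cache consumes, as
  // implied by the MemHandleCacheEntry constructor below (the field names
  // here are descriptive only, not part of any API):
  //
  //   legacy handle:    [cudaIpcMemHandle_t bytes]
  //                     exactly CUDA_IPC_HANDLE_SIZE bytes, treated as
  //                     SHAREABLE_CUDA_MALLOC
  //   versioned handle: [version byte][type byte][payload]
  //     type == SHAREABLE_CUDA_MALLOC:             payload is a
  //                                                cudaIpcMemHandle_t
  //     type == SHAREABLE_CUDA_EXPANDABLE_SEGMENT: payload is parsed by
  //                                                ExpandableSegment::fromShared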

  std::mutex IpcMutex;
  struct MemHandleCacheEntry {
    MemHandleCacheEntry(
        c10::DeviceIndex device,
        std::string& handle,
        const DeviceCachingAllocator& allocator)
        : device_(device),
          expandable_segment_(nullptr),
          cuda_ipc_ptr_(nullptr) {
      int type = SHAREABLE_CUDA_MALLOC;
      std::istringstream ss(handle);
      if (handle.size() != CUDA_IPC_HANDLE_SIZE) {
        auto version = ss.get();
        TORCH_CHECK(
            version <= SHAREABLE_HANDLE_VERSION,
            "received shareable handle from a future version of torch that this version does not know how to handle");
        type = ss.get();
      } // otherwise this is coming from an old pytorch where it has to be a
        // raw SHAREABLE_CUDA_MALLOC
      if (type == SHAREABLE_CUDA_MALLOC) {
        cudaIpcMemHandle_t cuda_handle;
        ss.read((char*)&cuda_handle, CUDA_IPC_HANDLE_SIZE);
        C10_CUDA_CHECK(cudaIpcOpenMemHandle(
            &cuda_ipc_ptr_, cuda_handle, cudaIpcMemLazyEnablePeerAccess));
      } else if (type == SHAREABLE_CUDA_EXPANDABLE_SEGMENT) {
        expandable_segment_ =
            ExpandableSegment::fromShared(device, allocator.peers(), ss)
                .release();
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "unexpected or ill-formed shareable handle type");
      }
    }
    // this struct expects clear() to be called explicitly to free resources,
    // because we only want that code to run when the shared_ptr to this entry
    // is destructed, not during process teardown when CUDA may already have
    // been shut down. This replicates the previous behavior of this map when
    // it stored raw cuda_ipc_ptr_ handles.
    void clear() {
      if (cuda_ipc_ptr_) {
        cuda::CUDAGuard device_guard(device_);
        C10_CUDA_CHECK(cudaIpcCloseMemHandle(cuda_ipc_ptr_));
        cuda_ipc_ptr_ = nullptr;
      }
      if (expandable_segment_) {
        delete expandable_segment_;
        expandable_segment_ = nullptr;
      }
    }
    void* ptr() {
      if (cuda_ipc_ptr_) {
        return cuda_ipc_ptr_;
      } else {
        return expandable_segment_->ptr();
      }
    }
    c10::DeviceIndex device_;
    ExpandableSegment* expandable_segment_;
    void* cuda_ipc_ptr_; // nullptr if expandable_segment_ is not null
    std::weak_ptr<void> wp_;
  };

  ska::flat_hash_map<std::string, MemHandleCacheEntry> ipcMemHandle_to_devptr;
  std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
    std::lock_guard<std::mutex> lock(IpcMutex);

    auto iter = ipcMemHandle_to_devptr.find(handle);
    if (iter != ipcMemHandle_to_devptr.end()) {
      auto devptr = iter->second.wp_.lock();
      // the weak_ptr should always be lockable here, because the cache entry
      // is erased when the corresponding shared_ptr is destructed, so a stale
      // entry should never be found.
      TORCH_INTERNAL_ASSERT(devptr, "entry in cache has missing shared_ptr");
      return devptr;
    }
    c10::DeviceIndex curr_device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&curr_device));
    auto inserted = ipcMemHandle_to_devptr.insert(
        iter,
        {handle,
         MemHandleCacheEntry(
             curr_device, handle, *device_allocator[curr_device])});
    auto sp = std::shared_ptr<void>(
        inserted->second.ptr(), [handle, this](void* ptr) {
          std::lock_guard<std::mutex> deleter_lock(IpcMutex);
          auto it = ipcMemHandle_to_devptr.find(handle);
          TORCH_INTERNAL_ASSERT(it != ipcMemHandle_to_devptr.end());
          it->second.clear();
          ipcMemHandle_to_devptr.erase(it);
        });
    inserted->second.wp_ = sp;
    return sp;
  }
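
  // A hedged usage sketch (caller-side, not part of this file): a receiving
  // process that has obtained a serialized handle string, e.g. over
  // torch.multiprocessing, could map it once and reuse the cached mapping:
  //
  //   std::shared_ptr<void> block = getIpcDevPtr(handle);
  //   void* base = block.get(); // base of the sender's memory block
  //   // ... construct typed storages at base + offset ...
  //   // the mapping is torn down (cudaIpcCloseMemHandle or expandable-segment
  //   // teardown) only when the last shared_ptr copy is released.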

  std::string name() override {
    return "native";
  }
  void copy_data(void* dest, const void* src, std::size_t count) const final {
    C10_CUDA_CHECK(
        cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
  }
};

NativeCachingAllocator allocator;

void local_raw_delete(void* ptr) {
  if (TORCH_SDT_IS_ENABLED(free)) {
    TORCH_SDT_WITH_SEMAPHORE(free, ptr);
  }

  allocator.free(ptr);
}

} // namespace Native

namespace CudaMallocAsync {
// If this is put in its own header file, it gets incorrectly renamed in
// HIPify.
CUDAAllocator* allocator();

} // namespace CudaMallocAsync

struct BackendStaticInitializer {
  // Parses the environment for the backend at load time, duplicating some
  // logic from CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later
  // (at runtime). Verbose exceptions and error checks, including CUDA version
  // checks, are deferred to CUDAAllocatorConfig's runtime double-check. If
  // this works, maybe we should move all of CUDAAllocatorConfig here?
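  //
  // For illustration, given how parseEnvForBackend() below splits the config
  // string on commas/whitespace and then on ':', a setting such as
  //
  //   PYTORCH_CUDA_ALLOC_CONF="backend:cudaMallocAsync"
  //
  // selects the cudaMallocAsync backend, while "backend:native" (or omitting
  // the backend key entirely) selects the native caching allocator. Any other
  // key:value pairs in the same string are ignored here and left to
  // CUDAAllocatorConfig.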
  CUDAAllocator* parseEnvForBackend() {
    const char* val = getenv("PYTORCH_CUDA_ALLOC_CONF");
    if (val != nullptr) {
      const std::string config(val);

      std::regex exp("[\\s,]+");
      std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
      std::sregex_token_iterator end;
      std::vector<std::string> options(it, end);

      for (const auto& option : options) {
        std::regex exp2("[:]+");
        std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
        std::sregex_token_iterator end2;
        std::vector<std::string> kv(it2, end2);
        if (kv.size() >= 2) {
          if (kv[0] == "backend") {
            if (kv[1] == "cudaMallocAsync")
              return CudaMallocAsync::allocator();
            if (kv[1] == "native")
              return &Native::allocator;
          }
        }
      }
    }
    return &Native::allocator;
  }

  BackendStaticInitializer() {
    auto r = parseEnvForBackend();
    allocator.store(r);
  }
};

std::atomic<CUDAAllocator*> allocator;
BackendStaticInitializer backend_static_initializer;
} // namespace cuda::CUDACachingAllocator
} // namespace c10

namespace c10::cuda {

// uid_ is incremented when a user creates a MemPool,
// for example via graph_pool_handle() or c10::cuda::MemPool().
//
// uuid_ is incremented when CUDAGraph creates a MemPool
// because the user did not provide one.
//
// A MempoolId_t of {0, 0} denotes that no MemPool has been passed to a
// function, either by the user or by CUDAGraph; for example, the default
// MempoolId_t for the capture_begin function is {0, 0}. That is why uid_ and
// uuid_ start at 1.
std::atomic<CaptureId_t> MemPool::uid_{1};
std::atomic<CaptureId_t> MemPool::uuid_{1};

MemPool::MemPool(
    CUDACachingAllocator::CUDAAllocator* allocator,
    bool is_user_created)
    : allocator_(allocator), is_user_created_(is_user_created) {
  if (is_user_created_) {
    id_ = {0, uid_++};
  } else {
    id_ = {uuid_++, 0};
  }
}
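
// Illustrative consequence of the constructor above (example ids only): the
// first user-created pool gets id {0, 1}, the next {0, 2}, and so on, while
// pools created internally by CUDAGraph get {1, 0}, {2, 0}, ... The two id
// namespaces therefore never collide, and neither ever produces the reserved
// "no pool" value {0, 0}.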

MempoolId_t MemPool::id() {
  return id_;
}

CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
  return allocator_;
}

// Note that active_mempool_ is a file-scope thread_local variable here rather
// than a member of MemPoolContext, because on Windows we can't use
// __declspec(dllexport) and __declspec(thread) together:
// https://stackoverflow.com/a/50967977
static thread_local MemPool* active_mempool_ = nullptr;

MemPoolContext::MemPoolContext(MemPool* mempool)
    : prev_mempool_(active_mempool_) {
  active_mempool_ = mempool;
}

MemPoolContext::~MemPoolContext() {
  active_mempool_ = prev_mempool_;
}

MemPool* MemPoolContext::getActiveMemPool() {
  return active_mempool_;
}
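
// A minimal usage sketch (hypothetical caller code, not part of this file),
// showing the RAII behavior implemented above:
//
//   c10::cuda::MemPool pool;                 // user-created pool
//   {
//     c10::cuda::MemPoolContext ctx(&pool);  // `pool` becomes active
//     // on this thread, code that cares about the active pool can consult
//     // MemPoolContext::getActiveMemPool() here
//   }                                        // previous active pool restored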

} // namespace c10::cuda