#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/xpu/XPUCachingAllocator.h>

#include <array>
#include <deque>
#include <memory>
#include <mutex>
#include <set>
#include <vector>

namespace c10::xpu::XPUCachingAllocator {

using namespace c10::CachingDeviceAllocator;

// newly allocated memory with 512-byte alignment.
constexpr size_t kDeviceAlignment = 512;
// all sizes are rounded to at least 512 bytes
constexpr size_t kMinBlockSize = 512;
// largest "small" allocation is 1 MiB
constexpr size_t kSmallSize = 1048576;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// "large" allocations may be packed in 20 MiB blocks
constexpr size_t kLargeBuffer = 20971520;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB
constexpr size_t kRoundLarge = 2097152;
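//
// Illustrative mapping of a request through these constants (derived from
// round_size() and get_allocation_size() below): a 1000-byte request is
// rounded up to 1024 bytes and, being <= kSmallSize, is served from a 2 MiB
// (kSmallBuffer) allocation; a 3 MiB request falls between kSmallSize and
// kMinLargeAlloc and is served from a 20 MiB (kLargeBuffer) allocation; a
// 21 MiB request is rounded up to the next multiple of kRoundLarge, i.e.
// 22 MiB.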

namespace {
using stream_set = ska::flat_hash_set<xpu::XPUStream>;

struct Block;
typedef bool (*Comparison)(const Block*, const Block*);
bool BlockComparatorSize(const Block* a, const Block* b);

struct BlockPool {
  BlockPool(bool small) : blocks(BlockComparatorSize), is_small(small) {}
  std::set<Block*, Comparison> blocks;
  const bool is_small;
};

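// A Block is the allocator's unit of bookkeeping: one contiguous region of
// device memory, tagged with the queue it was allocated on, the streams that
// have used it, and prev/next links to its siblings when it was split out of
// a larger allocation.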
struct Block {
  DeviceIndex device;
  sycl::queue* queue{nullptr}; // underlying queue of the allocation stream
  stream_set stream_uses; // streams on which the block was used
  size_t size; // block size in bytes
  size_t requested_size; // memory originally requested
  BlockPool* pool{nullptr}; // owning memory pool
  void* ptr{nullptr}; // memory address
  bool allocated{false}; // in-use flag
  Block* prev{nullptr}; // prev block if split from a larger allocation
  Block* next{nullptr}; // next block if split from a larger allocation
  int event_count{0}; // number of outstanding XPU events

  Block(
      DeviceIndex device,
      sycl::queue* queue,
      size_t size,
      BlockPool* pool,
      void* ptr)
      : device(device),
        queue(queue),
        stream_uses(),
        size(size),
        requested_size(0),
        pool(pool),
        ptr(ptr) {}

  // constructor for search key
  Block(DeviceIndex device, sycl::queue* queue, size_t size)
      : device(device),
        queue(queue),
        stream_uses(),
        size(size),
        requested_size(0) {}

  bool is_split() const {
    return (prev != nullptr) || (next != nullptr);
  }
};

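// Blocks in a pool are ordered first by queue, then by size, then by address,
// so lower_bound() on a (queue, size) search key yields the smallest cached
// block on the same queue that is large enough (best fit).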
bool BlockComparatorSize(const Block* a, const Block* b) {
  if (a->queue != b->queue) {
    return reinterpret_cast<uintptr_t>(a->queue) <
        reinterpret_cast<uintptr_t>(b->queue);
  }
  if (a->size != b->size) {
    return a->size < b->size;
  }
  return reinterpret_cast<uintptr_t>(a->ptr) <
      reinterpret_cast<uintptr_t>(b->ptr);
}

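// AllocParams bundles everything needed to satisfy one allocation request:
// the search key used to probe the pool, the pool itself, the rounded
// allocation size, and the block eventually chosen.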
struct AllocParams {
  AllocParams(
      DeviceIndex device,
      size_t size,
      sycl::queue* queue,
      BlockPool* pool,
      size_t alloc_size)
      : search_key(device, queue, size),
        pool(pool),
        alloc_size(alloc_size),
        block(nullptr) {}

  DeviceIndex device() const {
    return search_key.device;
  }

  sycl::queue* queue() const {
    return search_key.queue;
  }

  size_t size() const {
    return search_key.size;
  }

  Block search_key;
  BlockPool* pool;
  size_t alloc_size;
  Block* block;
  StatTypes stat_types = {};
};

} // anonymous namespace

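// Per-device caching allocator. All state is guarded by a recursive mutex.
// Freed blocks are cached in two pools (small and large) keyed by queue and
// size, and blocks used on side streams are only recycled once the barrier
// events recorded for those streams have completed.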
class DeviceCachingAllocator {
 private:
  mutable std::recursive_mutex mutex;
  DeviceStats stats;
  BlockPool large_blocks; // unallocated cached blocks larger than 1 MB
  BlockPool small_blocks; // unallocated cached blocks 1 MB or smaller
  ska::flat_hash_set<Block*> active_blocks; // allocated or in use by a stream
  ska::flat_hash_map<xpu::XPUStream, std::deque<std::pair<sycl::event, Block*>>>
      xpu_events;
  DeviceIndex device_index;

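  // Try to coalesce dst with a free neighbor src (its prev or next in the
  // split chain). Returns the number of bytes subsumed, or 0 if src is still
  // allocated, has outstanding events, or is in use on other streams.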
  size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
    if (!src || src->allocated || src->event_count > 0 ||
        !src->stream_uses.empty()) {
      return 0;
    }

    TORCH_INTERNAL_ASSERT(dst->is_split() && src->is_split());
    if (dst->prev == src) { // [src dst]
      dst->ptr = src->ptr;
      dst->prev = src->prev;
      if (dst->prev) {
        dst->prev->next = dst;
      }
    } else { // [dst src]
      dst->next = src->next;
      if (dst->next) {
        dst->next->prev = dst;
      }
    }
    const size_t subsumed_size = src->size;
    dst->size += subsumed_size;
    auto erased = pool.blocks.erase(src);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
    delete src;

    return subsumed_size;
  }

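  // Return an inactive block to its pool, merging it with free neighbors and
  // updating the active/requested byte counters.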
  void free_block(Block* block) {
    TORCH_INTERNAL_ASSERT(
        !block->allocated && block->event_count == 0 &&
        block->stream_uses.empty());

    size_t original_block_size = block->size;
    size_t requested_size = block->requested_size;
    auto& pool = *block->pool;
    const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
    for (Block* merge_candidate : merge_candidates) {
      try_merge_blocks(block, merge_candidate, pool);
    }

    active_blocks.erase(block);
    bool inserted = pool.blocks.insert(block).second;
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);

    StatTypes stat_types = get_stat_types_for_pool(pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.active_bytes[stat_type].decrease(original_block_size);
      stats.requested_bytes[stat_type].decrease(requested_size);
    });
  }

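  // Poll the barrier events recorded for side-stream usage. Events are
  // processed in FIFO order per stream; once a block's outstanding event
  // count drops to zero it is returned to its pool.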
  void process_events() {
    using namespace sycl::info;
    for (auto it = xpu_events.begin(); it != xpu_events.end();) {
      while (!it->second.empty()) {
        auto& e = it->second.front();
        auto event = e.first;
        auto* block = e.second;
        if (event.get_info<event::command_execution_status>() !=
            event_command_status::complete) {
          break;
        }
        block->event_count--;
        if (block->event_count == 0) {
          free_block(block);
        }
        it->second.pop_front();
      }

      if (it->second.empty()) {
        it = xpu_events.erase(it);
      } else {
        it++;
      }
    }
  }

  static size_t round_size(size_t size) {
    if (size < kMinBlockSize) {
      return kMinBlockSize;
    } else {
      return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
    }
  }

  static size_t get_allocation_size(size_t size) {
    if (size <= kSmallSize) {
      return kSmallBuffer;
    } else if (size < kMinLargeAlloc) {
      return kLargeBuffer;
    } else {
      return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
    }
  }

  BlockPool& get_pool(size_t size) {
    if (size < kSmallSize) {
      return small_blocks;
    } else {
      return large_blocks;
    }
  }

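  // Look up the smallest cached block that can satisfy the request. The
  // lower_bound hit must be on the same queue as the request; otherwise the
  // pool has nothing reusable for this stream.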
  bool get_free_block(AllocParams& p) {
    BlockPool& pool = *p.pool;
    auto it = pool.blocks.lower_bound(&p.search_key);
    if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
      return false;
    }
    p.block = *it;
    pool.blocks.erase(it);
    return true;
  }

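  // Allocate a fresh buffer from the SYCL runtime. Returns false on failure
  // so the caller can release cached blocks and retry.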
  bool alloc_block(AllocParams& p) {
    auto size = p.alloc_size;
    auto device = p.device();
    void* ptr = sycl::aligned_alloc_device(
        kDeviceAlignment,
        size,
        xpu::get_raw_device(device),
        xpu::get_device_context());
    if (!ptr) {
      return false;
    }
    p.block = new Block(device, p.queue(), size, p.pool, ptr);
    for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
      stats.reserved_bytes[stat_type].increase(size);
    });
    return true;
  }

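  // Wait for every outstanding event, then release the blocks that were held
  // back waiting on them.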
  void synchronize_and_free_events() {
    for (auto& xe : xpu_events) {
      for (auto& e : xe.second) {
        auto event = e.first;
        auto* block = e.second;
        event.wait();
        block->event_count--;
        if (block->event_count == 0) {
          free_block(block);
        }
      }
    }
    xpu_events.clear();
  }

  void release_block(Block* block) {
    /*
     * Note [Safe to Free Blocks on BlockPool]
     *
     * Callers must ensure that all accesses to the block, whose raw pointer is
     * allocated by SYCL APIs, have been completed before invoking sycl::free.
     *
     * We have to perform a device-level synchronization before freeing these
     * blocks to guarantee that all kernels that may access them have finished.
     */
    sycl::free(block->ptr, xpu::get_device_context());
    auto* pool = block->pool;
    pool->blocks.erase(block);

    StatTypes stat_types = get_stat_types_for_pool(*pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.reserved_bytes[stat_type].decrease(block->size);
    });

    delete block;
  }

  void release_blocks(BlockPool& pool) {
    auto it = pool.blocks.begin();
    while (it != pool.blocks.end()) {
      Block* block = *it;
      ++it;
      if (!block->prev && !block->next) {
        release_block(block);
      }
    }
  }

  bool release_cached_blocks() {
    synchronize_and_free_events();
    // See Note [Safe to Free Blocks on BlockPool]
    c10::xpu::syncStreamsOnDevice(device_index);

    release_blocks(large_blocks);
    release_blocks(small_blocks);
    return true;
  }

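  // Decide whether the leftover space of a cached block is worth splitting
  // off: any >= 512 B remainder in the small pool, or a remainder larger than
  // 1 MiB in the large pool.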
  bool should_split(const Block* block, size_t size) {
    size_t remaining = block->size - size;
    if (block->pool->is_small) {
      return remaining >= kMinBlockSize;
    } else {
      return remaining > kSmallSize;
    }
  }

  StatTypes get_stat_types_for_pool(const BlockPool& pool) {
    StatTypes stat_types = {};
    stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
    stat_types[static_cast<size_t>(
        pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true;
    return stat_types;
  }

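  // Finalize a block obtained from the pool or freshly allocated: optionally
  // split off the unused tail as a new free block, mark the head allocated,
  // and record it in active_blocks and the statistics.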
  Block* alloc_found_block(
      AllocParams params,
      size_t orig_size,
      bool split_remainder) {
    auto size = params.size();
    auto device = params.device();
    BlockPool* pool = params.pool;
    sycl::queue* queue = params.queue();

    TORCH_INTERNAL_ASSERT(
        params.block != nullptr && params.block->ptr != nullptr);
    Block* block = params.block;
    Block* remaining = nullptr;

    if (split_remainder) {
      remaining = block;

      block = new Block(device, queue, size, pool, block->ptr);
      block->prev = remaining->prev;
      if (block->prev) {
        block->prev->next = block;
      }
      block->next = remaining;

      remaining->prev = block;
      remaining->ptr = static_cast<char*>(remaining->ptr) + size;
      remaining->size -= size;
      bool inserted = pool->blocks.insert(remaining).second;
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
    }

    block->allocated = true;
    block->requested_size = orig_size;
    bool inserted = active_blocks.insert(block).second;
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);

    for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
      stats.allocated_bytes[stat_type].increase(block->size);
      stats.active_bytes[stat_type].increase(block->size);
      stats.requested_bytes[stat_type].increase(block->requested_size);
    });

    return block;
  }

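  // Record a barrier event on every stream that used this block. The block
  // stays out of the pool until process_events() observes all of them as
  // complete.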
  void insert_events(Block* block) {
    stream_set streams(std::move(block->stream_uses));
    TORCH_INTERNAL_ASSERT(block->stream_uses.empty());
    for (auto& stream : streams) {
      block->event_count++;
      xpu_events[stream].emplace_back(
          stream.queue().ext_oneapi_submit_barrier(), block);
    }
  }

 public:
  DeviceCachingAllocator(DeviceIndex device_index)
      : large_blocks(/* small */ false),
        small_blocks(/* small */ true),
        device_index(device_index) {}

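  // Allocation entry point: recycle completed events, round the request,
  // then try (1) a cached block from the matching pool, (2) a fresh device
  // allocation, (3) releasing all cached blocks and allocating again. If all
  // of that fails, report an out-of-memory error with the current statistics.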
  Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    process_events();
    size_t size = round_size(orig_size);
    auto& pool = get_pool(size);
    const size_t alloc_size = get_allocation_size(size);
    AllocParams params(device, size, &queue, &pool, alloc_size);
    params.stat_types = get_stat_types_for_pool(pool);

    // First, try to get a block from the existing pool.
    bool block_found = get_free_block(params);
    // Can't reuse an existing block, try to get a new one.
    if (!block_found) {
      block_found = alloc_block(params) ||
          (release_cached_blocks() && alloc_block(params));
    }
    if (!block_found) {
      c10::xpu::DeviceProp device_prop;
      c10::xpu::get_device_properties(&device_prop, device);
      auto device_total = device_prop.global_mem_size;
      auto allocated_bytes =
          stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
              .current;
      auto reserved_bytes =
          stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
              .current;
      TORCH_CHECK_WITH(
          OutOfMemoryError,
          false,
          "XPU out of memory. Tried to allocate ",
          format_size(alloc_size),
          ". GPU ",
          static_cast<int>(device),
          " has a total capacity of ",
          format_size(device_total),
          ". Of the allocated memory ",
          format_size(allocated_bytes),
          " is allocated by PyTorch, and ",
          format_size(reserved_bytes - allocated_bytes),
          " is reserved by PyTorch but unallocated.",
          " Please use `empty_cache` to release all unoccupied cached memory.");
    }
    bool split_remainder = should_split(params.block, params.size());
    return alloc_found_block(std::move(params), orig_size, split_remainder);
  }

  void free(Block* block) {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    block->allocated = false;

    StatTypes stat_types = get_stat_types_for_pool(*block->pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.allocated_bytes[stat_type].decrease(block->size);
    });

    if (!block->stream_uses.empty()) {
      insert_events(block);
    } else {
      free_block(block);
    }
  }

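  // Mark that the block has been used on `stream`. Uses on streams other than
  // the allocation stream force the block to wait on barrier events before it
  // can be reused.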
  void recordStream(Block* block, xpu::XPUStream stream) {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    if (stream.queue() == *block->queue) {
      return;
    }
    block->stream_uses.insert(stream);
  }

  void emptyCache() {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    release_cached_blocks();
  }

  DeviceStats getStats() {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    return stats;
  }

  void resetAccumulatedStats() {
    std::scoped_lock<std::recursive_mutex> lock(mutex);

    for (const auto statType :
         c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
      stats.allocated_bytes[statType].reset_accumulated();
      stats.reserved_bytes[statType].reset_accumulated();
      stats.active_bytes[statType].reset_accumulated();
      stats.requested_bytes[statType].reset_accumulated();
    }
  }

  void resetPeakStats() {
    std::scoped_lock<std::recursive_mutex> lock(mutex);

    for (const auto statType :
         c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
      stats.allocated_bytes[statType].reset_peak();
      stats.reserved_bytes[statType].reset_peak();
      stats.active_bytes[statType].reset_peak();
      stats.requested_bytes[statType].reset_peak();
    }
  }
};

void local_raw_delete(void* ptr);

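// The c10::Allocator implementation exposed to the rest of PyTorch. It routes
// every request to the DeviceCachingAllocator of the target device and keeps
// a pointer-to-Block map so that free()/recordStream() can recover the block
// from a raw device pointer.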
class XPUAllocator : public Allocator {
 private:
  std::mutex mutex;
  ska::flat_hash_map<void*, Block*> allocated_blocks;

  void add_allocated_block(Block* block) {
    std::lock_guard<std::mutex> lock(mutex);
    allocated_blocks[block->ptr] = block;
  }

  Block* get_allocated_block(void* ptr, bool remove = false) {
    std::scoped_lock<std::mutex> lock(mutex);
    auto it = allocated_blocks.find(ptr);
    if (it == allocated_blocks.end()) {
      return nullptr;
    }
    Block* block = it->second;
    if (remove) {
      allocated_blocks.erase(it);
    }
    return block;
  }

 public:
  std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;

  void init(DeviceIndex device_count) {
    const auto size = static_cast<DeviceIndex>(device_allocators.size());
    if (size < device_count) {
      device_allocators.resize(device_count);
      for (const auto i : c10::irange(size, device_count)) {
        device_allocators[i] = std::make_unique<DeviceCachingAllocator>(i);
      }
    }
  }

  void malloc(
      void** devPtr,
      DeviceIndex device,
      size_t size,
      sycl::queue& queue) {
    TORCH_INTERNAL_ASSERT(
        0 <= device && static_cast<size_t>(device) < device_allocators.size(),
        "Allocator not initialized for device ",
        static_cast<int16_t>(device),
        ": did you call init?");
    Block* block = device_allocators[device]->malloc(device, size, queue);
    add_allocated_block(block);
    *devPtr = block->ptr;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_memory_allocation(
          c10::kXPU, reinterpret_cast<uintptr_t>(*devPtr));
    }
  }

  void free(void* ptr) {
    if (!ptr) {
      return;
    }
    Block* block = get_allocated_block(ptr, /* remove */ true);
    TORCH_CHECK(block, "invalid device pointer: ", ptr);
    device_allocators[block->device]->free(block);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_memory_deallocation(
          c10::kXPU, reinterpret_cast<uintptr_t>(block->ptr));
    }
  }

  void emptyCache() {
    for (auto& da : device_allocators) {
      da->emptyCache();
    }
  }

  void recordStream(const DataPtr& ptr, XPUStream stream) {
    if (!ptr.get()) {
      return;
    }
    if (ptr.get_deleter() != &local_raw_delete) {
      return;
    }

    Block* block = get_allocated_block(ptr.get());
    TORCH_CHECK(block, "No allocated block can be found.");
    device_allocators[block->device]->recordStream(block, stream);
  }

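  // c10::Allocator interface: allocate on the current device and its current
  // stream, returning a DataPtr whose deleter routes back through this
  // allocator.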
  DataPtr allocate(size_t size) override {
    auto device = c10::xpu::current_device();
    void* r = nullptr;
    if (size != 0) {
      this->malloc(&r, device, size, xpu::getCurrentXPUStream(device));
    }
    return {r, r, &local_raw_delete, Device(DeviceType::XPU, device)};
  }

  DeleterFnPtr raw_deleter() const override {
    return &local_raw_delete;
  }

  void* raw_alloc(size_t size) {
    if (size == 0) {
      return nullptr;
    }
    auto device = c10::xpu::current_device();
    void* r = nullptr;
    malloc(&r, device, size, xpu::getCurrentXPUStream(device));
    return r;
  }

  void* raw_alloc_with_stream(size_t size, XPUStream stream) {
    if (size == 0) {
      return nullptr;
    }
    auto device = c10::xpu::current_device();
    void* r = nullptr;
    malloc(&r, device, size, stream);
    return r;
  }

  void raw_delete(void* ptr) {
    this->free(ptr);
  }

  void copy_data(void* dest, const void* src, std::size_t count) const final {
    xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
  }

  void assertValidDevice(DeviceIndex device) {
    const auto device_num = device_allocators.size();
    TORCH_CHECK(
        0 <= device && device < static_cast<int64_t>(device_num),
        "Invalid device argument ",
        device,
        ": did you call init?");
  }

  DeviceStats getDeviceStats(DeviceIndex device) {
    assertValidDevice(device);
    return device_allocators[device]->getStats();
  }

  void resetPeakStats(DeviceIndex device) {
    assertValidDevice(device);
    device_allocators[device]->resetPeakStats();
  }

  void resetAccumulatedStats(DeviceIndex device) {
    assertValidDevice(device);
    device_allocators[device]->resetAccumulatedStats();
  }
};

static XPUAllocator allocator;

void local_raw_delete(void* ptr) {
  allocator.free(ptr);
}

Allocator* get() {
  return &allocator;
}

void init(DeviceIndex device_count) {
  return allocator.init(device_count);
}

void emptyCache() {
  return allocator.emptyCache();
}

void resetPeakStats(DeviceIndex device) {
  return allocator.resetPeakStats(device);
}

void resetAccumulatedStats(DeviceIndex device) {
  return allocator.resetAccumulatedStats(device);
}

DeviceStats getDeviceStats(DeviceIndex device) {
  return allocator.getDeviceStats(device);
}

void* raw_alloc(size_t size) {
  return allocator.raw_alloc(size);
}

void raw_delete(void* ptr) {
  return allocator.raw_delete(ptr);
}

void recordStream(const DataPtr& dataPtr, XPUStream stream) {
  return allocator.recordStream(dataPtr, stream);
}

REGISTER_ALLOCATOR(kXPU, &allocator)
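
// Illustrative usage from elsewhere in the codebase (a sketch, not part of
// this file's API surface): after init() has been called with the device
// count, callers can allocate and release raw device memory through the
// namespace-level helpers defined above.
//
//   c10::xpu::XPUCachingAllocator::init(device_count);
//   void* p = c10::xpu::XPUCachingAllocator::raw_alloc(4096);
//   // ... launch kernels using p ...
//   c10::xpu::XPUCachingAllocator::raw_delete(p);
//   c10::xpu::XPUCachingAllocator::emptyCache();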

} // namespace c10::xpu::XPUCachingAllocator