1*da0073e9SAndroid Build Coastguard Worker #include <c10/util/flat_hash_map.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUCachingAllocator.h>
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker #include <deque>
6*da0073e9SAndroid Build Coastguard Worker #include <mutex>
7*da0073e9SAndroid Build Coastguard Worker #include <set>
8*da0073e9SAndroid Build Coastguard Worker #include <vector>
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker namespace c10::xpu::XPUCachingAllocator {
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker using namespace c10::CachingDeviceAllocator;
13*da0073e9SAndroid Build Coastguard Worker
// Alignment, in bytes, requested for every fresh device allocation.
constexpr size_t kDeviceAlignment = 512;
// Request granularity: sizes are rounded up to a multiple of 512 bytes and
// never go below 512 bytes.
constexpr size_t kMinBlockSize = 512;
// Largest request treated as "small": 1 MiB.
constexpr size_t kSmallSize = 1024 * 1024;
// Segment size that backs "small" requests: 2 MiB.
constexpr size_t kSmallBuffer = 2 * 1024 * 1024;
// Segment size that "large" requests may be packed into: 20 MiB.
constexpr size_t kLargeBuffer = 20 * 1024 * 1024;
// Requests above kSmallSize but below this 10 MiB threshold get a
// kLargeBuffer segment; bigger ones are rounded up to kRoundLarge instead.
constexpr size_t kMinLargeAlloc = 10 * 1024 * 1024;
// Rounding granularity for large segments: 2 MiB.
constexpr size_t kRoundLarge = 2 * 1024 * 1024;
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker namespace {
30*da0073e9SAndroid Build Coastguard Worker using stream_set = ska::flat_hash_set<xpu::XPUStream>;
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker struct Block;
33*da0073e9SAndroid Build Coastguard Worker typedef bool (*Comparison)(const Block*, const Block*);
34*da0073e9SAndroid Build Coastguard Worker bool BlockComparatorSize(const Block* a, const Block* b);
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker struct BlockPool {
BlockPoolc10::xpu::XPUCachingAllocator::__anon8131195a0111::BlockPool37*da0073e9SAndroid Build Coastguard Worker BlockPool(bool small) : blocks(BlockComparatorSize), is_small(small) {}
38*da0073e9SAndroid Build Coastguard Worker std::set<Block*, Comparison> blocks;
39*da0073e9SAndroid Build Coastguard Worker const bool is_small;
40*da0073e9SAndroid Build Coastguard Worker };
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker struct Block {
43*da0073e9SAndroid Build Coastguard Worker DeviceIndex device;
44*da0073e9SAndroid Build Coastguard Worker sycl::queue* queue{nullptr}; // underlying queue of the allocation stream
45*da0073e9SAndroid Build Coastguard Worker stream_set stream_uses; // streams on which the block was used
46*da0073e9SAndroid Build Coastguard Worker size_t size; // block size in bytes
47*da0073e9SAndroid Build Coastguard Worker size_t requested_size; // memory originally requested
48*da0073e9SAndroid Build Coastguard Worker BlockPool* pool{nullptr}; // owning memory pool
49*da0073e9SAndroid Build Coastguard Worker void* ptr{nullptr}; // memory address
50*da0073e9SAndroid Build Coastguard Worker bool allocated{false}; // in-use flag
51*da0073e9SAndroid Build Coastguard Worker Block* prev{nullptr}; // prev block if split from a larger allocation
52*da0073e9SAndroid Build Coastguard Worker Block* next{nullptr}; // next block if split from a larger allocation
53*da0073e9SAndroid Build Coastguard Worker int event_count{0}; // number of outstanding XPU events
54*da0073e9SAndroid Build Coastguard Worker
Blockc10::xpu::XPUCachingAllocator::__anon8131195a0111::Block55*da0073e9SAndroid Build Coastguard Worker Block(
56*da0073e9SAndroid Build Coastguard Worker DeviceIndex device,
57*da0073e9SAndroid Build Coastguard Worker sycl::queue* queue,
58*da0073e9SAndroid Build Coastguard Worker size_t size,
59*da0073e9SAndroid Build Coastguard Worker BlockPool* pool,
60*da0073e9SAndroid Build Coastguard Worker void* ptr)
61*da0073e9SAndroid Build Coastguard Worker : device(device),
62*da0073e9SAndroid Build Coastguard Worker queue(queue),
63*da0073e9SAndroid Build Coastguard Worker stream_uses(),
64*da0073e9SAndroid Build Coastguard Worker size(size),
65*da0073e9SAndroid Build Coastguard Worker requested_size(0),
66*da0073e9SAndroid Build Coastguard Worker pool(pool),
67*da0073e9SAndroid Build Coastguard Worker ptr(ptr) {}
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker // constructor for search key
Blockc10::xpu::XPUCachingAllocator::__anon8131195a0111::Block70*da0073e9SAndroid Build Coastguard Worker Block(DeviceIndex device, sycl::queue* queue, size_t size)
71*da0073e9SAndroid Build Coastguard Worker : device(device),
72*da0073e9SAndroid Build Coastguard Worker queue(queue),
73*da0073e9SAndroid Build Coastguard Worker stream_uses(),
74*da0073e9SAndroid Build Coastguard Worker size(size),
75*da0073e9SAndroid Build Coastguard Worker requested_size(0) {}
76*da0073e9SAndroid Build Coastguard Worker
is_splitc10::xpu::XPUCachingAllocator::__anon8131195a0111::Block77*da0073e9SAndroid Build Coastguard Worker bool is_split() const {
78*da0073e9SAndroid Build Coastguard Worker return (prev != nullptr) || (next != nullptr);
79*da0073e9SAndroid Build Coastguard Worker }
80*da0073e9SAndroid Build Coastguard Worker };
81*da0073e9SAndroid Build Coastguard Worker
BlockComparatorSize(const Block * a,const Block * b)82*da0073e9SAndroid Build Coastguard Worker bool BlockComparatorSize(const Block* a, const Block* b) {
83*da0073e9SAndroid Build Coastguard Worker if (a->queue != b->queue) {
84*da0073e9SAndroid Build Coastguard Worker return reinterpret_cast<uintptr_t>(a->queue) <
85*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<uintptr_t>(b->queue);
86*da0073e9SAndroid Build Coastguard Worker }
87*da0073e9SAndroid Build Coastguard Worker if (a->size != b->size) {
88*da0073e9SAndroid Build Coastguard Worker return a->size < b->size;
89*da0073e9SAndroid Build Coastguard Worker }
90*da0073e9SAndroid Build Coastguard Worker return reinterpret_cast<uintptr_t>(a->ptr) <
91*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<uintptr_t>(b->ptr);
92*da0073e9SAndroid Build Coastguard Worker }
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker struct AllocParams {
AllocParamsc10::xpu::XPUCachingAllocator::__anon8131195a0111::AllocParams95*da0073e9SAndroid Build Coastguard Worker AllocParams(
96*da0073e9SAndroid Build Coastguard Worker DeviceIndex device,
97*da0073e9SAndroid Build Coastguard Worker size_t size,
98*da0073e9SAndroid Build Coastguard Worker sycl::queue* queue,
99*da0073e9SAndroid Build Coastguard Worker BlockPool* pool,
100*da0073e9SAndroid Build Coastguard Worker size_t alloc_size)
101*da0073e9SAndroid Build Coastguard Worker : search_key(device, queue, size),
102*da0073e9SAndroid Build Coastguard Worker pool(pool),
103*da0073e9SAndroid Build Coastguard Worker alloc_size(alloc_size),
104*da0073e9SAndroid Build Coastguard Worker block(nullptr) {}
105*da0073e9SAndroid Build Coastguard Worker
devicec10::xpu::XPUCachingAllocator::__anon8131195a0111::AllocParams106*da0073e9SAndroid Build Coastguard Worker DeviceIndex device() const {
107*da0073e9SAndroid Build Coastguard Worker return search_key.device;
108*da0073e9SAndroid Build Coastguard Worker }
109*da0073e9SAndroid Build Coastguard Worker
queuec10::xpu::XPUCachingAllocator::__anon8131195a0111::AllocParams110*da0073e9SAndroid Build Coastguard Worker sycl::queue* queue() const {
111*da0073e9SAndroid Build Coastguard Worker return search_key.queue;
112*da0073e9SAndroid Build Coastguard Worker }
113*da0073e9SAndroid Build Coastguard Worker
sizec10::xpu::XPUCachingAllocator::__anon8131195a0111::AllocParams114*da0073e9SAndroid Build Coastguard Worker size_t size() const {
115*da0073e9SAndroid Build Coastguard Worker return search_key.size;
116*da0073e9SAndroid Build Coastguard Worker }
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker Block search_key;
119*da0073e9SAndroid Build Coastguard Worker BlockPool* pool;
120*da0073e9SAndroid Build Coastguard Worker size_t alloc_size;
121*da0073e9SAndroid Build Coastguard Worker Block* block;
122*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = {};
123*da0073e9SAndroid Build Coastguard Worker };
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker } // anonymous namespace
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker class DeviceCachingAllocator {
128*da0073e9SAndroid Build Coastguard Worker private:
129*da0073e9SAndroid Build Coastguard Worker mutable std::recursive_mutex mutex;
130*da0073e9SAndroid Build Coastguard Worker DeviceStats stats;
131*da0073e9SAndroid Build Coastguard Worker BlockPool large_blocks; // unallocated cached blocks larger than 1 MB
132*da0073e9SAndroid Build Coastguard Worker BlockPool small_blocks; // unallocated cached blocks 1 MB or smaller
133*da0073e9SAndroid Build Coastguard Worker ska::flat_hash_set<Block*> active_blocks; // allocated or in use by a stream
134*da0073e9SAndroid Build Coastguard Worker ska::flat_hash_map<xpu::XPUStream, std::deque<std::pair<sycl::event, Block*>>>
135*da0073e9SAndroid Build Coastguard Worker xpu_events;
136*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_index;
137*da0073e9SAndroid Build Coastguard Worker
try_merge_blocks(Block * dst,Block * src,BlockPool & pool)138*da0073e9SAndroid Build Coastguard Worker size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
139*da0073e9SAndroid Build Coastguard Worker if (!src || src->allocated || src->event_count > 0 ||
140*da0073e9SAndroid Build Coastguard Worker !src->stream_uses.empty()) {
141*da0073e9SAndroid Build Coastguard Worker return 0;
142*da0073e9SAndroid Build Coastguard Worker }
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(dst->is_split() && src->is_split());
145*da0073e9SAndroid Build Coastguard Worker if (dst->prev == src) { // [src dst]
146*da0073e9SAndroid Build Coastguard Worker dst->ptr = src->ptr;
147*da0073e9SAndroid Build Coastguard Worker dst->prev = src->prev;
148*da0073e9SAndroid Build Coastguard Worker if (dst->prev) {
149*da0073e9SAndroid Build Coastguard Worker dst->prev->next = dst;
150*da0073e9SAndroid Build Coastguard Worker }
151*da0073e9SAndroid Build Coastguard Worker } else { // [dst src]
152*da0073e9SAndroid Build Coastguard Worker dst->next = src->next;
153*da0073e9SAndroid Build Coastguard Worker if (dst->next) {
154*da0073e9SAndroid Build Coastguard Worker dst->next->prev = dst;
155*da0073e9SAndroid Build Coastguard Worker }
156*da0073e9SAndroid Build Coastguard Worker }
157*da0073e9SAndroid Build Coastguard Worker const size_t subsumed_size = src->size;
158*da0073e9SAndroid Build Coastguard Worker dst->size += subsumed_size;
159*da0073e9SAndroid Build Coastguard Worker auto erased = pool.blocks.erase(src);
160*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
161*da0073e9SAndroid Build Coastguard Worker delete src;
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker return subsumed_size;
164*da0073e9SAndroid Build Coastguard Worker }
165*da0073e9SAndroid Build Coastguard Worker
free_block(Block * block)166*da0073e9SAndroid Build Coastguard Worker void free_block(Block* block) {
167*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
168*da0073e9SAndroid Build Coastguard Worker !block->allocated && block->event_count == 0 &&
169*da0073e9SAndroid Build Coastguard Worker block->stream_uses.empty());
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker size_t original_block_size = block->size;
172*da0073e9SAndroid Build Coastguard Worker size_t requested_size = block->requested_size;
173*da0073e9SAndroid Build Coastguard Worker auto& pool = *block->pool;
174*da0073e9SAndroid Build Coastguard Worker const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
175*da0073e9SAndroid Build Coastguard Worker for (Block* merge_candidate : merge_candidates) {
176*da0073e9SAndroid Build Coastguard Worker try_merge_blocks(block, merge_candidate, pool);
177*da0073e9SAndroid Build Coastguard Worker }
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker active_blocks.erase(block);
180*da0073e9SAndroid Build Coastguard Worker bool inserted = pool.blocks.insert(block).second;
181*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = get_stat_types_for_pool(pool);
184*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
185*da0073e9SAndroid Build Coastguard Worker stats.active_bytes[stat_type].decrease(original_block_size);
186*da0073e9SAndroid Build Coastguard Worker stats.requested_bytes[stat_type].decrease(requested_size);
187*da0073e9SAndroid Build Coastguard Worker });
188*da0073e9SAndroid Build Coastguard Worker }
189*da0073e9SAndroid Build Coastguard Worker
process_events()190*da0073e9SAndroid Build Coastguard Worker void process_events() {
191*da0073e9SAndroid Build Coastguard Worker using namespace sycl::info;
192*da0073e9SAndroid Build Coastguard Worker for (auto it = xpu_events.begin(); it != xpu_events.end();) {
193*da0073e9SAndroid Build Coastguard Worker while (!it->second.empty()) {
194*da0073e9SAndroid Build Coastguard Worker auto& e = it->second.front();
195*da0073e9SAndroid Build Coastguard Worker auto event = e.first;
196*da0073e9SAndroid Build Coastguard Worker auto* block = e.second;
197*da0073e9SAndroid Build Coastguard Worker if (event.get_info<event::command_execution_status>() !=
198*da0073e9SAndroid Build Coastguard Worker event_command_status::complete) {
199*da0073e9SAndroid Build Coastguard Worker break;
200*da0073e9SAndroid Build Coastguard Worker }
201*da0073e9SAndroid Build Coastguard Worker block->event_count--;
202*da0073e9SAndroid Build Coastguard Worker if (block->event_count == 0) {
203*da0073e9SAndroid Build Coastguard Worker free_block(block);
204*da0073e9SAndroid Build Coastguard Worker }
205*da0073e9SAndroid Build Coastguard Worker it->second.pop_front();
206*da0073e9SAndroid Build Coastguard Worker }
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker if (it->second.empty()) {
209*da0073e9SAndroid Build Coastguard Worker it = xpu_events.erase(it);
210*da0073e9SAndroid Build Coastguard Worker } else {
211*da0073e9SAndroid Build Coastguard Worker it++;
212*da0073e9SAndroid Build Coastguard Worker }
213*da0073e9SAndroid Build Coastguard Worker }
214*da0073e9SAndroid Build Coastguard Worker }
215*da0073e9SAndroid Build Coastguard Worker
round_size(size_t size)216*da0073e9SAndroid Build Coastguard Worker static size_t round_size(size_t size) {
217*da0073e9SAndroid Build Coastguard Worker if (size < kMinBlockSize) {
218*da0073e9SAndroid Build Coastguard Worker return kMinBlockSize;
219*da0073e9SAndroid Build Coastguard Worker } else {
220*da0073e9SAndroid Build Coastguard Worker return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
221*da0073e9SAndroid Build Coastguard Worker }
222*da0073e9SAndroid Build Coastguard Worker }
223*da0073e9SAndroid Build Coastguard Worker
get_allocation_size(size_t size)224*da0073e9SAndroid Build Coastguard Worker static size_t get_allocation_size(size_t size) {
225*da0073e9SAndroid Build Coastguard Worker if (size <= kSmallSize) {
226*da0073e9SAndroid Build Coastguard Worker return kSmallBuffer;
227*da0073e9SAndroid Build Coastguard Worker } else if (size < kMinLargeAlloc) {
228*da0073e9SAndroid Build Coastguard Worker return kLargeBuffer;
229*da0073e9SAndroid Build Coastguard Worker } else {
230*da0073e9SAndroid Build Coastguard Worker return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
231*da0073e9SAndroid Build Coastguard Worker }
232*da0073e9SAndroid Build Coastguard Worker }
233*da0073e9SAndroid Build Coastguard Worker
get_pool(size_t size)234*da0073e9SAndroid Build Coastguard Worker BlockPool& get_pool(size_t size) {
235*da0073e9SAndroid Build Coastguard Worker if (size < kSmallSize) {
236*da0073e9SAndroid Build Coastguard Worker return small_blocks;
237*da0073e9SAndroid Build Coastguard Worker } else {
238*da0073e9SAndroid Build Coastguard Worker return large_blocks;
239*da0073e9SAndroid Build Coastguard Worker }
240*da0073e9SAndroid Build Coastguard Worker }
241*da0073e9SAndroid Build Coastguard Worker
get_free_block(AllocParams & p)242*da0073e9SAndroid Build Coastguard Worker bool get_free_block(AllocParams& p) {
243*da0073e9SAndroid Build Coastguard Worker BlockPool& pool = *p.pool;
244*da0073e9SAndroid Build Coastguard Worker auto it = pool.blocks.lower_bound(&p.search_key);
245*da0073e9SAndroid Build Coastguard Worker if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
246*da0073e9SAndroid Build Coastguard Worker return false;
247*da0073e9SAndroid Build Coastguard Worker }
248*da0073e9SAndroid Build Coastguard Worker p.block = *it;
249*da0073e9SAndroid Build Coastguard Worker pool.blocks.erase(it);
250*da0073e9SAndroid Build Coastguard Worker return true;
251*da0073e9SAndroid Build Coastguard Worker }
252*da0073e9SAndroid Build Coastguard Worker
alloc_block(AllocParams & p)253*da0073e9SAndroid Build Coastguard Worker bool alloc_block(AllocParams& p) {
254*da0073e9SAndroid Build Coastguard Worker auto size = p.alloc_size;
255*da0073e9SAndroid Build Coastguard Worker auto device = p.device();
256*da0073e9SAndroid Build Coastguard Worker void* ptr = sycl::aligned_alloc_device(
257*da0073e9SAndroid Build Coastguard Worker kDeviceAlignment,
258*da0073e9SAndroid Build Coastguard Worker size,
259*da0073e9SAndroid Build Coastguard Worker xpu::get_raw_device(device),
260*da0073e9SAndroid Build Coastguard Worker xpu::get_device_context());
261*da0073e9SAndroid Build Coastguard Worker if (!ptr) {
262*da0073e9SAndroid Build Coastguard Worker return false;
263*da0073e9SAndroid Build Coastguard Worker }
264*da0073e9SAndroid Build Coastguard Worker p.block = new Block(device, p.queue(), size, p.pool, ptr);
265*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
266*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[stat_type].increase(size);
267*da0073e9SAndroid Build Coastguard Worker });
268*da0073e9SAndroid Build Coastguard Worker return true;
269*da0073e9SAndroid Build Coastguard Worker }
270*da0073e9SAndroid Build Coastguard Worker
synchronize_and_free_events()271*da0073e9SAndroid Build Coastguard Worker void synchronize_and_free_events() {
272*da0073e9SAndroid Build Coastguard Worker for (auto& xe : xpu_events) {
273*da0073e9SAndroid Build Coastguard Worker for (auto& e : xe.second) {
274*da0073e9SAndroid Build Coastguard Worker auto event = e.first;
275*da0073e9SAndroid Build Coastguard Worker auto* block = e.second;
276*da0073e9SAndroid Build Coastguard Worker event.wait();
277*da0073e9SAndroid Build Coastguard Worker block->event_count--;
278*da0073e9SAndroid Build Coastguard Worker if (block->event_count == 0) {
279*da0073e9SAndroid Build Coastguard Worker free_block(block);
280*da0073e9SAndroid Build Coastguard Worker }
281*da0073e9SAndroid Build Coastguard Worker }
282*da0073e9SAndroid Build Coastguard Worker }
283*da0073e9SAndroid Build Coastguard Worker xpu_events.clear();
284*da0073e9SAndroid Build Coastguard Worker }
285*da0073e9SAndroid Build Coastguard Worker
release_block(Block * block)286*da0073e9SAndroid Build Coastguard Worker void release_block(Block* block) {
287*da0073e9SAndroid Build Coastguard Worker /*
288*da0073e9SAndroid Build Coastguard Worker * Note [Safe to Free Blocks on BlockPool]
289*da0073e9SAndroid Build Coastguard Worker *
290*da0073e9SAndroid Build Coastguard Worker * Callers must ensure that all accesses to the block, whose raw pointer is
291*da0073e9SAndroid Build Coastguard Worker * allocated by SYCL APIs, have been completed before invoking sycl::free.
292*da0073e9SAndroid Build Coastguard Worker *
293*da0073e9SAndroid Build Coastguard Worker * We have to do a device-level synchronization before free these blocks to
294*da0073e9SAndroid Build Coastguard Worker * guarantee that all kernels can access to the blocks have finished.
295*da0073e9SAndroid Build Coastguard Worker */
296*da0073e9SAndroid Build Coastguard Worker sycl::free(block->ptr, xpu::get_device_context());
297*da0073e9SAndroid Build Coastguard Worker auto* pool = block->pool;
298*da0073e9SAndroid Build Coastguard Worker pool->blocks.erase(block);
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = get_stat_types_for_pool(*pool);
301*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
302*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[stat_type].decrease(block->size);
303*da0073e9SAndroid Build Coastguard Worker });
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker delete block;
306*da0073e9SAndroid Build Coastguard Worker }
307*da0073e9SAndroid Build Coastguard Worker
release_blocks(BlockPool & pool)308*da0073e9SAndroid Build Coastguard Worker void release_blocks(BlockPool& pool) {
309*da0073e9SAndroid Build Coastguard Worker auto it = pool.blocks.begin();
310*da0073e9SAndroid Build Coastguard Worker while (it != pool.blocks.end()) {
311*da0073e9SAndroid Build Coastguard Worker Block* block = *it;
312*da0073e9SAndroid Build Coastguard Worker ++it;
313*da0073e9SAndroid Build Coastguard Worker if (!block->prev && !block->next) {
314*da0073e9SAndroid Build Coastguard Worker release_block(block);
315*da0073e9SAndroid Build Coastguard Worker }
316*da0073e9SAndroid Build Coastguard Worker }
317*da0073e9SAndroid Build Coastguard Worker }
318*da0073e9SAndroid Build Coastguard Worker
release_cached_blocks()319*da0073e9SAndroid Build Coastguard Worker bool release_cached_blocks() {
320*da0073e9SAndroid Build Coastguard Worker synchronize_and_free_events();
321*da0073e9SAndroid Build Coastguard Worker // See Note [Safe to Free Blocks on BlockPool]
322*da0073e9SAndroid Build Coastguard Worker c10::xpu::syncStreamsOnDevice(device_index);
323*da0073e9SAndroid Build Coastguard Worker
324*da0073e9SAndroid Build Coastguard Worker release_blocks(large_blocks);
325*da0073e9SAndroid Build Coastguard Worker release_blocks(small_blocks);
326*da0073e9SAndroid Build Coastguard Worker return true;
327*da0073e9SAndroid Build Coastguard Worker }
328*da0073e9SAndroid Build Coastguard Worker
should_split(const Block * block,size_t size)329*da0073e9SAndroid Build Coastguard Worker bool should_split(const Block* block, size_t size) {
330*da0073e9SAndroid Build Coastguard Worker size_t remaining = block->size - size;
331*da0073e9SAndroid Build Coastguard Worker if (block->pool->is_small) {
332*da0073e9SAndroid Build Coastguard Worker return remaining >= kMinBlockSize;
333*da0073e9SAndroid Build Coastguard Worker } else {
334*da0073e9SAndroid Build Coastguard Worker return remaining > kSmallSize;
335*da0073e9SAndroid Build Coastguard Worker }
336*da0073e9SAndroid Build Coastguard Worker }
337*da0073e9SAndroid Build Coastguard Worker
get_stat_types_for_pool(const BlockPool & pool)338*da0073e9SAndroid Build Coastguard Worker StatTypes get_stat_types_for_pool(const BlockPool& pool) {
339*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = {};
340*da0073e9SAndroid Build Coastguard Worker stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
341*da0073e9SAndroid Build Coastguard Worker stat_types[static_cast<size_t>(
342*da0073e9SAndroid Build Coastguard Worker pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true;
343*da0073e9SAndroid Build Coastguard Worker return stat_types;
344*da0073e9SAndroid Build Coastguard Worker }
345*da0073e9SAndroid Build Coastguard Worker
alloc_found_block(AllocParams params,size_t orig_size,bool split_remainder)346*da0073e9SAndroid Build Coastguard Worker Block* alloc_found_block(
347*da0073e9SAndroid Build Coastguard Worker AllocParams params,
348*da0073e9SAndroid Build Coastguard Worker size_t orig_size,
349*da0073e9SAndroid Build Coastguard Worker bool split_remainder) {
350*da0073e9SAndroid Build Coastguard Worker auto size = params.size();
351*da0073e9SAndroid Build Coastguard Worker auto device = params.device();
352*da0073e9SAndroid Build Coastguard Worker BlockPool* pool = params.pool;
353*da0073e9SAndroid Build Coastguard Worker sycl::queue* queue = params.queue();
354*da0073e9SAndroid Build Coastguard Worker
355*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
356*da0073e9SAndroid Build Coastguard Worker params.block != nullptr && params.block->ptr != nullptr);
357*da0073e9SAndroid Build Coastguard Worker Block* block = params.block;
358*da0073e9SAndroid Build Coastguard Worker Block* remaining = nullptr;
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker if (split_remainder) {
361*da0073e9SAndroid Build Coastguard Worker remaining = block;
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker block = new Block(device, queue, size, pool, block->ptr);
364*da0073e9SAndroid Build Coastguard Worker block->prev = remaining->prev;
365*da0073e9SAndroid Build Coastguard Worker if (block->prev) {
366*da0073e9SAndroid Build Coastguard Worker block->prev->next = block;
367*da0073e9SAndroid Build Coastguard Worker }
368*da0073e9SAndroid Build Coastguard Worker block->next = remaining;
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker remaining->prev = block;
371*da0073e9SAndroid Build Coastguard Worker remaining->ptr = static_cast<char*>(remaining->ptr) + size;
372*da0073e9SAndroid Build Coastguard Worker remaining->size -= size;
373*da0073e9SAndroid Build Coastguard Worker bool inserted = pool->blocks.insert(remaining).second;
374*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
375*da0073e9SAndroid Build Coastguard Worker }
376*da0073e9SAndroid Build Coastguard Worker
377*da0073e9SAndroid Build Coastguard Worker block->allocated = true;
378*da0073e9SAndroid Build Coastguard Worker block->requested_size = orig_size;
379*da0073e9SAndroid Build Coastguard Worker bool inserted = active_blocks.insert(block).second;
380*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted)
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
383*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[stat_type].increase(block->size);
384*da0073e9SAndroid Build Coastguard Worker stats.active_bytes[stat_type].increase(block->size);
385*da0073e9SAndroid Build Coastguard Worker stats.requested_bytes[stat_type].increase(block->requested_size);
386*da0073e9SAndroid Build Coastguard Worker });
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker return block;
389*da0073e9SAndroid Build Coastguard Worker }
390*da0073e9SAndroid Build Coastguard Worker
insert_events(Block * block)391*da0073e9SAndroid Build Coastguard Worker void insert_events(Block* block) {
392*da0073e9SAndroid Build Coastguard Worker stream_set streams(std::move(block->stream_uses));
393*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(block->stream_uses.empty());
394*da0073e9SAndroid Build Coastguard Worker for (auto& stream : streams) {
395*da0073e9SAndroid Build Coastguard Worker block->event_count++;
396*da0073e9SAndroid Build Coastguard Worker xpu_events[stream].emplace_back(
397*da0073e9SAndroid Build Coastguard Worker stream.queue().ext_oneapi_submit_barrier(), block);
398*da0073e9SAndroid Build Coastguard Worker }
399*da0073e9SAndroid Build Coastguard Worker }
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker public:
DeviceCachingAllocator(DeviceIndex device_index)402*da0073e9SAndroid Build Coastguard Worker DeviceCachingAllocator(DeviceIndex device_index)
403*da0073e9SAndroid Build Coastguard Worker : large_blocks(/* small */ false),
404*da0073e9SAndroid Build Coastguard Worker small_blocks(/* small */ true),
405*da0073e9SAndroid Build Coastguard Worker device_index(device_index) {}
406*da0073e9SAndroid Build Coastguard Worker
malloc(DeviceIndex device,size_t orig_size,sycl::queue & queue)407*da0073e9SAndroid Build Coastguard Worker Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) {
408*da0073e9SAndroid Build Coastguard Worker std::scoped_lock<std::recursive_mutex> lock(mutex);
409*da0073e9SAndroid Build Coastguard Worker process_events();
410*da0073e9SAndroid Build Coastguard Worker size_t size = round_size(orig_size);
411*da0073e9SAndroid Build Coastguard Worker auto& pool = get_pool(size);
412*da0073e9SAndroid Build Coastguard Worker const size_t alloc_size = get_allocation_size(size);
413*da0073e9SAndroid Build Coastguard Worker AllocParams params(device, size, &queue, &pool, alloc_size);
414*da0073e9SAndroid Build Coastguard Worker params.stat_types = get_stat_types_for_pool(pool);
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker // First, try to get a block from the existing pool.
417*da0073e9SAndroid Build Coastguard Worker bool block_found = get_free_block(params);
418*da0073e9SAndroid Build Coastguard Worker // Can't reuse an existing block, try to get a new one.
419*da0073e9SAndroid Build Coastguard Worker if (!block_found) {
420*da0073e9SAndroid Build Coastguard Worker block_found = alloc_block(params) ||
421*da0073e9SAndroid Build Coastguard Worker (release_cached_blocks() && alloc_block(params));
422*da0073e9SAndroid Build Coastguard Worker }
423*da0073e9SAndroid Build Coastguard Worker if (!block_found) {
424*da0073e9SAndroid Build Coastguard Worker c10::xpu::DeviceProp device_prop;
425*da0073e9SAndroid Build Coastguard Worker c10::xpu::get_device_properties(&device_prop, device);
426*da0073e9SAndroid Build Coastguard Worker auto device_total = device_prop.global_mem_size;
427*da0073e9SAndroid Build Coastguard Worker auto allocated_bytes =
428*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
429*da0073e9SAndroid Build Coastguard Worker .current;
430*da0073e9SAndroid Build Coastguard Worker auto reserved_bytes =
431*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
432*da0073e9SAndroid Build Coastguard Worker .current;
433*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK_WITH(
434*da0073e9SAndroid Build Coastguard Worker OutOfMemoryError,
435*da0073e9SAndroid Build Coastguard Worker false,
436*da0073e9SAndroid Build Coastguard Worker "XPU out of memory. Tried to allocate ",
437*da0073e9SAndroid Build Coastguard Worker format_size(alloc_size),
438*da0073e9SAndroid Build Coastguard Worker ". GPU ",
439*da0073e9SAndroid Build Coastguard Worker static_cast<int>(device),
440*da0073e9SAndroid Build Coastguard Worker " has a total capacity of ",
441*da0073e9SAndroid Build Coastguard Worker format_size(device_total),
442*da0073e9SAndroid Build Coastguard Worker ". Of the allocated memory ",
443*da0073e9SAndroid Build Coastguard Worker format_size(allocated_bytes),
444*da0073e9SAndroid Build Coastguard Worker " is allocated by PyTorch, and ",
445*da0073e9SAndroid Build Coastguard Worker format_size(reserved_bytes - allocated_bytes),
446*da0073e9SAndroid Build Coastguard Worker " is reserved by PyTorch but unallocated.",
447*da0073e9SAndroid Build Coastguard Worker " Please use `empty_cache` to release all unoccupied cached memory.");
448*da0073e9SAndroid Build Coastguard Worker }
449*da0073e9SAndroid Build Coastguard Worker bool split_remainder = should_split(params.block, params.size());
450*da0073e9SAndroid Build Coastguard Worker return alloc_found_block(std::move(params), orig_size, split_remainder);
451*da0073e9SAndroid Build Coastguard Worker }
452*da0073e9SAndroid Build Coastguard Worker
free(Block * block)453*da0073e9SAndroid Build Coastguard Worker void free(Block* block) {
454*da0073e9SAndroid Build Coastguard Worker std::scoped_lock<std::recursive_mutex> lock(mutex);
455*da0073e9SAndroid Build Coastguard Worker block->allocated = false;
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker StatTypes stat_types = get_stat_types_for_pool(*block->pool);
458*da0073e9SAndroid Build Coastguard Worker for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
459*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[stat_type].decrease(block->size);
460*da0073e9SAndroid Build Coastguard Worker });
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker if (!block->stream_uses.empty()) {
463*da0073e9SAndroid Build Coastguard Worker insert_events(block);
464*da0073e9SAndroid Build Coastguard Worker } else {
465*da0073e9SAndroid Build Coastguard Worker free_block(block);
466*da0073e9SAndroid Build Coastguard Worker }
467*da0073e9SAndroid Build Coastguard Worker }
468*da0073e9SAndroid Build Coastguard Worker
recordStream(Block * block,xpu::XPUStream stream)469*da0073e9SAndroid Build Coastguard Worker void recordStream(Block* block, xpu::XPUStream stream) {
470*da0073e9SAndroid Build Coastguard Worker std::scoped_lock<std::recursive_mutex> lock(mutex);
471*da0073e9SAndroid Build Coastguard Worker if (stream.queue() == *block->queue) {
472*da0073e9SAndroid Build Coastguard Worker return;
473*da0073e9SAndroid Build Coastguard Worker }
474*da0073e9SAndroid Build Coastguard Worker block->stream_uses.insert(stream);
475*da0073e9SAndroid Build Coastguard Worker }
476*da0073e9SAndroid Build Coastguard Worker
emptyCache()477*da0073e9SAndroid Build Coastguard Worker void emptyCache() {
478*da0073e9SAndroid Build Coastguard Worker std::scoped_lock<std::recursive_mutex> lock(mutex);
479*da0073e9SAndroid Build Coastguard Worker release_cached_blocks();
480*da0073e9SAndroid Build Coastguard Worker }
481*da0073e9SAndroid Build Coastguard Worker
getStats()482*da0073e9SAndroid Build Coastguard Worker DeviceStats getStats() {
483*da0073e9SAndroid Build Coastguard Worker std::scoped_lock<std::recursive_mutex> lock(mutex);
484*da0073e9SAndroid Build Coastguard Worker return stats;
485*da0073e9SAndroid Build Coastguard Worker }
486*da0073e9SAndroid Build Coastguard Worker
resetAccumulatedStats()487*da0073e9SAndroid Build Coastguard Worker void resetAccumulatedStats() {
488*da0073e9SAndroid Build Coastguard Worker std::scoped_lock<std::recursive_mutex> lock(mutex);
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker for (const auto statType :
491*da0073e9SAndroid Build Coastguard Worker c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
492*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[statType].reset_accumulated();
493*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[statType].reset_accumulated();
494*da0073e9SAndroid Build Coastguard Worker stats.active_bytes[statType].reset_accumulated();
495*da0073e9SAndroid Build Coastguard Worker stats.requested_bytes[statType].reset_accumulated();
496*da0073e9SAndroid Build Coastguard Worker }
497*da0073e9SAndroid Build Coastguard Worker }
498*da0073e9SAndroid Build Coastguard Worker
resetPeakStats()499*da0073e9SAndroid Build Coastguard Worker void resetPeakStats() {
500*da0073e9SAndroid Build Coastguard Worker std::scoped_lock<std::recursive_mutex> lock(mutex);
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker for (const auto statType :
503*da0073e9SAndroid Build Coastguard Worker c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
504*da0073e9SAndroid Build Coastguard Worker stats.allocated_bytes[statType].reset_peak();
505*da0073e9SAndroid Build Coastguard Worker stats.reserved_bytes[statType].reset_peak();
506*da0073e9SAndroid Build Coastguard Worker stats.active_bytes[statType].reset_peak();
507*da0073e9SAndroid Build Coastguard Worker stats.requested_bytes[statType].reset_peak();
508*da0073e9SAndroid Build Coastguard Worker }
509*da0073e9SAndroid Build Coastguard Worker }
510*da0073e9SAndroid Build Coastguard Worker };
511*da0073e9SAndroid Build Coastguard Worker
// Declared ahead of XPUAllocator because the class uses its address as the
// DataPtr deleter (allocate, raw_deleter, recordStream); the definition
// follows the global allocator instance below.
void local_raw_delete(void* /* ptr */);
513*da0073e9SAndroid Build Coastguard Worker
514*da0073e9SAndroid Build Coastguard Worker class XPUAllocator : public Allocator {
515*da0073e9SAndroid Build Coastguard Worker private:
516*da0073e9SAndroid Build Coastguard Worker std::mutex mutex;
517*da0073e9SAndroid Build Coastguard Worker ska::flat_hash_map<void*, Block*> allocated_blocks;
518*da0073e9SAndroid Build Coastguard Worker
add_allocated_block(Block * block)519*da0073e9SAndroid Build Coastguard Worker void add_allocated_block(Block* block) {
520*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock(mutex);
521*da0073e9SAndroid Build Coastguard Worker allocated_blocks[block->ptr] = block;
522*da0073e9SAndroid Build Coastguard Worker }
523*da0073e9SAndroid Build Coastguard Worker
get_allocated_block(void * ptr,bool remove=false)524*da0073e9SAndroid Build Coastguard Worker Block* get_allocated_block(void* ptr, bool remove = false) {
525*da0073e9SAndroid Build Coastguard Worker std::scoped_lock<std::mutex> lock(mutex);
526*da0073e9SAndroid Build Coastguard Worker auto it = allocated_blocks.find(ptr);
527*da0073e9SAndroid Build Coastguard Worker if (it == allocated_blocks.end()) {
528*da0073e9SAndroid Build Coastguard Worker return nullptr;
529*da0073e9SAndroid Build Coastguard Worker }
530*da0073e9SAndroid Build Coastguard Worker Block* block = it->second;
531*da0073e9SAndroid Build Coastguard Worker if (remove) {
532*da0073e9SAndroid Build Coastguard Worker allocated_blocks.erase(it);
533*da0073e9SAndroid Build Coastguard Worker }
534*da0073e9SAndroid Build Coastguard Worker return block;
535*da0073e9SAndroid Build Coastguard Worker }
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker public:
538*da0073e9SAndroid Build Coastguard Worker std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;
539*da0073e9SAndroid Build Coastguard Worker
init(DeviceIndex device_count)540*da0073e9SAndroid Build Coastguard Worker void init(DeviceIndex device_count) {
541*da0073e9SAndroid Build Coastguard Worker const auto size = static_cast<DeviceIndex>(device_allocators.size());
542*da0073e9SAndroid Build Coastguard Worker if (size < device_count) {
543*da0073e9SAndroid Build Coastguard Worker device_allocators.resize(device_count);
544*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(size, device_count)) {
545*da0073e9SAndroid Build Coastguard Worker device_allocators[i] = std::make_unique<DeviceCachingAllocator>(i);
546*da0073e9SAndroid Build Coastguard Worker }
547*da0073e9SAndroid Build Coastguard Worker }
548*da0073e9SAndroid Build Coastguard Worker }
549*da0073e9SAndroid Build Coastguard Worker
malloc(void ** devPtr,DeviceIndex device,size_t size,sycl::queue & queue)550*da0073e9SAndroid Build Coastguard Worker void malloc(
551*da0073e9SAndroid Build Coastguard Worker void** devPtr,
552*da0073e9SAndroid Build Coastguard Worker DeviceIndex device,
553*da0073e9SAndroid Build Coastguard Worker size_t size,
554*da0073e9SAndroid Build Coastguard Worker sycl::queue& queue) {
555*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
556*da0073e9SAndroid Build Coastguard Worker 0 <= device && static_cast<size_t>(device) < device_allocators.size(),
557*da0073e9SAndroid Build Coastguard Worker "Allocator not initialized for device ",
558*da0073e9SAndroid Build Coastguard Worker static_cast<int16_t>(device),
559*da0073e9SAndroid Build Coastguard Worker ": did you call init?");
560*da0073e9SAndroid Build Coastguard Worker Block* block = device_allocators[device]->malloc(device, size, queue);
561*da0073e9SAndroid Build Coastguard Worker add_allocated_block(block);
562*da0073e9SAndroid Build Coastguard Worker *devPtr = block->ptr;
563*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
564*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) {
565*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_memory_allocation(
566*da0073e9SAndroid Build Coastguard Worker c10::kXPU, reinterpret_cast<uintptr_t>(*devPtr));
567*da0073e9SAndroid Build Coastguard Worker }
568*da0073e9SAndroid Build Coastguard Worker }
569*da0073e9SAndroid Build Coastguard Worker
free(void * ptr)570*da0073e9SAndroid Build Coastguard Worker void free(void* ptr) {
571*da0073e9SAndroid Build Coastguard Worker if (!ptr) {
572*da0073e9SAndroid Build Coastguard Worker return;
573*da0073e9SAndroid Build Coastguard Worker }
574*da0073e9SAndroid Build Coastguard Worker Block* block = get_allocated_block(ptr, /* remove */ true);
575*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(block, "invalid device pointer: ", ptr);
576*da0073e9SAndroid Build Coastguard Worker device_allocators[block->device]->free(block);
577*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
578*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) {
579*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_memory_deallocation(
580*da0073e9SAndroid Build Coastguard Worker c10::kXPU, reinterpret_cast<uintptr_t>(block->ptr));
581*da0073e9SAndroid Build Coastguard Worker }
582*da0073e9SAndroid Build Coastguard Worker }
583*da0073e9SAndroid Build Coastguard Worker
emptyCache()584*da0073e9SAndroid Build Coastguard Worker void emptyCache() {
585*da0073e9SAndroid Build Coastguard Worker for (auto& da : device_allocators) {
586*da0073e9SAndroid Build Coastguard Worker da->emptyCache();
587*da0073e9SAndroid Build Coastguard Worker }
588*da0073e9SAndroid Build Coastguard Worker }
589*da0073e9SAndroid Build Coastguard Worker
recordStream(const DataPtr & ptr,XPUStream stream)590*da0073e9SAndroid Build Coastguard Worker void recordStream(const DataPtr& ptr, XPUStream stream) {
591*da0073e9SAndroid Build Coastguard Worker if (!ptr.get()) {
592*da0073e9SAndroid Build Coastguard Worker return;
593*da0073e9SAndroid Build Coastguard Worker }
594*da0073e9SAndroid Build Coastguard Worker if (ptr.get_deleter() != &local_raw_delete) {
595*da0073e9SAndroid Build Coastguard Worker return;
596*da0073e9SAndroid Build Coastguard Worker }
597*da0073e9SAndroid Build Coastguard Worker
598*da0073e9SAndroid Build Coastguard Worker Block* block = get_allocated_block(ptr.get());
599*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(block, "No allocated block can be found.");
600*da0073e9SAndroid Build Coastguard Worker device_allocators[block->device]->recordStream(block, stream);
601*da0073e9SAndroid Build Coastguard Worker }
602*da0073e9SAndroid Build Coastguard Worker
allocate(size_t size)603*da0073e9SAndroid Build Coastguard Worker DataPtr allocate(size_t size) override {
604*da0073e9SAndroid Build Coastguard Worker auto device = c10::xpu::current_device();
605*da0073e9SAndroid Build Coastguard Worker void* r = nullptr;
606*da0073e9SAndroid Build Coastguard Worker if (size != 0) {
607*da0073e9SAndroid Build Coastguard Worker this->malloc(&r, device, size, xpu::getCurrentXPUStream(device));
608*da0073e9SAndroid Build Coastguard Worker }
609*da0073e9SAndroid Build Coastguard Worker return {r, r, &local_raw_delete, Device(DeviceType::XPU, device)};
610*da0073e9SAndroid Build Coastguard Worker }
611*da0073e9SAndroid Build Coastguard Worker
raw_deleter() const612*da0073e9SAndroid Build Coastguard Worker DeleterFnPtr raw_deleter() const override {
613*da0073e9SAndroid Build Coastguard Worker return &local_raw_delete;
614*da0073e9SAndroid Build Coastguard Worker }
615*da0073e9SAndroid Build Coastguard Worker
raw_alloc(size_t size)616*da0073e9SAndroid Build Coastguard Worker void* raw_alloc(size_t size) {
617*da0073e9SAndroid Build Coastguard Worker if (size == 0) {
618*da0073e9SAndroid Build Coastguard Worker return nullptr;
619*da0073e9SAndroid Build Coastguard Worker }
620*da0073e9SAndroid Build Coastguard Worker auto device = c10::xpu::current_device();
621*da0073e9SAndroid Build Coastguard Worker void* r = nullptr;
622*da0073e9SAndroid Build Coastguard Worker malloc(&r, device, size, xpu::getCurrentXPUStream(device));
623*da0073e9SAndroid Build Coastguard Worker return r;
624*da0073e9SAndroid Build Coastguard Worker }
625*da0073e9SAndroid Build Coastguard Worker
raw_alloc_with_stream(size_t size,XPUStream stream)626*da0073e9SAndroid Build Coastguard Worker void* raw_alloc_with_stream(size_t size, XPUStream stream) {
627*da0073e9SAndroid Build Coastguard Worker if (size == 0) {
628*da0073e9SAndroid Build Coastguard Worker return nullptr;
629*da0073e9SAndroid Build Coastguard Worker }
630*da0073e9SAndroid Build Coastguard Worker auto device = c10::xpu::current_device();
631*da0073e9SAndroid Build Coastguard Worker void* r = nullptr;
632*da0073e9SAndroid Build Coastguard Worker malloc(&r, device, size, stream);
633*da0073e9SAndroid Build Coastguard Worker return r;
634*da0073e9SAndroid Build Coastguard Worker }
635*da0073e9SAndroid Build Coastguard Worker
raw_delete(void * ptr)636*da0073e9SAndroid Build Coastguard Worker void raw_delete(void* ptr) {
637*da0073e9SAndroid Build Coastguard Worker this->free(ptr);
638*da0073e9SAndroid Build Coastguard Worker }
639*da0073e9SAndroid Build Coastguard Worker
copy_data(void * dest,const void * src,std::size_t count) const640*da0073e9SAndroid Build Coastguard Worker void copy_data(void* dest, const void* src, std::size_t count) const final {
641*da0073e9SAndroid Build Coastguard Worker xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
642*da0073e9SAndroid Build Coastguard Worker }
643*da0073e9SAndroid Build Coastguard Worker
assertValidDevice(DeviceIndex device)644*da0073e9SAndroid Build Coastguard Worker void assertValidDevice(DeviceIndex device) {
645*da0073e9SAndroid Build Coastguard Worker const auto device_num = device_allocators.size();
646*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
647*da0073e9SAndroid Build Coastguard Worker 0 <= device && device < static_cast<int64_t>(device_num),
648*da0073e9SAndroid Build Coastguard Worker "Invalid device argument ",
649*da0073e9SAndroid Build Coastguard Worker device,
650*da0073e9SAndroid Build Coastguard Worker ": did you call init?");
651*da0073e9SAndroid Build Coastguard Worker }
652*da0073e9SAndroid Build Coastguard Worker
getDeviceStats(DeviceIndex device)653*da0073e9SAndroid Build Coastguard Worker DeviceStats getDeviceStats(DeviceIndex device) {
654*da0073e9SAndroid Build Coastguard Worker assertValidDevice(device);
655*da0073e9SAndroid Build Coastguard Worker return device_allocators[device]->getStats();
656*da0073e9SAndroid Build Coastguard Worker }
657*da0073e9SAndroid Build Coastguard Worker
resetPeakStats(DeviceIndex device)658*da0073e9SAndroid Build Coastguard Worker void resetPeakStats(DeviceIndex device) {
659*da0073e9SAndroid Build Coastguard Worker assertValidDevice(device);
660*da0073e9SAndroid Build Coastguard Worker device_allocators[device]->resetPeakStats();
661*da0073e9SAndroid Build Coastguard Worker }
662*da0073e9SAndroid Build Coastguard Worker
resetAccumulatedStats(DeviceIndex device)663*da0073e9SAndroid Build Coastguard Worker void resetAccumulatedStats(DeviceIndex device) {
664*da0073e9SAndroid Build Coastguard Worker assertValidDevice(device);
665*da0073e9SAndroid Build Coastguard Worker device_allocators[device]->resetAccumulatedStats();
666*da0073e9SAndroid Build Coastguard Worker }
667*da0073e9SAndroid Build Coastguard Worker };
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker static XPUAllocator allocator;
670*da0073e9SAndroid Build Coastguard Worker
local_raw_delete(void * ptr)671*da0073e9SAndroid Build Coastguard Worker void local_raw_delete(void* ptr) {
672*da0073e9SAndroid Build Coastguard Worker allocator.free(ptr);
673*da0073e9SAndroid Build Coastguard Worker }
674*da0073e9SAndroid Build Coastguard Worker
get()675*da0073e9SAndroid Build Coastguard Worker Allocator* get() {
676*da0073e9SAndroid Build Coastguard Worker return &allocator;
677*da0073e9SAndroid Build Coastguard Worker }
678*da0073e9SAndroid Build Coastguard Worker
init(DeviceIndex device_count)679*da0073e9SAndroid Build Coastguard Worker void init(DeviceIndex device_count) {
680*da0073e9SAndroid Build Coastguard Worker return allocator.init(device_count);
681*da0073e9SAndroid Build Coastguard Worker }
682*da0073e9SAndroid Build Coastguard Worker
emptyCache()683*da0073e9SAndroid Build Coastguard Worker void emptyCache() {
684*da0073e9SAndroid Build Coastguard Worker return allocator.emptyCache();
685*da0073e9SAndroid Build Coastguard Worker }
686*da0073e9SAndroid Build Coastguard Worker
resetPeakStats(DeviceIndex device)687*da0073e9SAndroid Build Coastguard Worker void resetPeakStats(DeviceIndex device) {
688*da0073e9SAndroid Build Coastguard Worker return allocator.resetPeakStats(device);
689*da0073e9SAndroid Build Coastguard Worker }
690*da0073e9SAndroid Build Coastguard Worker
resetAccumulatedStats(DeviceIndex device)691*da0073e9SAndroid Build Coastguard Worker void resetAccumulatedStats(DeviceIndex device) {
692*da0073e9SAndroid Build Coastguard Worker return allocator.resetAccumulatedStats(device);
693*da0073e9SAndroid Build Coastguard Worker }
694*da0073e9SAndroid Build Coastguard Worker
getDeviceStats(DeviceIndex device)695*da0073e9SAndroid Build Coastguard Worker DeviceStats getDeviceStats(DeviceIndex device) {
696*da0073e9SAndroid Build Coastguard Worker return allocator.getDeviceStats(device);
697*da0073e9SAndroid Build Coastguard Worker }
698*da0073e9SAndroid Build Coastguard Worker
raw_alloc(size_t size)699*da0073e9SAndroid Build Coastguard Worker void* raw_alloc(size_t size) {
700*da0073e9SAndroid Build Coastguard Worker return allocator.raw_alloc(size);
701*da0073e9SAndroid Build Coastguard Worker }
702*da0073e9SAndroid Build Coastguard Worker
raw_delete(void * ptr)703*da0073e9SAndroid Build Coastguard Worker void raw_delete(void* ptr) {
704*da0073e9SAndroid Build Coastguard Worker return allocator.raw_delete(ptr);
705*da0073e9SAndroid Build Coastguard Worker }
706*da0073e9SAndroid Build Coastguard Worker
recordStream(const DataPtr & dataPtr,XPUStream stream)707*da0073e9SAndroid Build Coastguard Worker void recordStream(const DataPtr& dataPtr, XPUStream stream) {
708*da0073e9SAndroid Build Coastguard Worker return allocator.recordStream(dataPtr, stream);
709*da0073e9SAndroid Build Coastguard Worker }
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker REGISTER_ALLOCATOR(kXPU, &allocator)
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Worker } // namespace c10::xpu::XPUCachingAllocator
714