xref: /aosp_15_r20/external/pytorch/c10/mobile/CPUCachingAllocator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/alloc_cpu.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/mobile/CPUCachingAllocator.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker namespace c10 {
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker namespace {
8*da0073e9SAndroid Build Coastguard Worker thread_local CPUCachingAllocator* caching_allocator_ptr{nullptr};
9*da0073e9SAndroid Build Coastguard Worker } // namespace
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker std::mutex CPUCachingAllocator::mutex_;
12*da0073e9SAndroid Build Coastguard Worker ska::flat_hash_map<void*, size_t> CPUCachingAllocator::allocation_map_;
13*da0073e9SAndroid Build Coastguard Worker 
allocate_and_cache(const size_t bytes)14*da0073e9SAndroid Build Coastguard Worker inline void* CPUCachingAllocator::allocate_and_cache(const size_t bytes) {
15*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
16*da0073e9SAndroid Build Coastguard Worker   void* ptr;
17*da0073e9SAndroid Build Coastguard Worker   try {
18*da0073e9SAndroid Build Coastguard Worker     ptr = c10::alloc_cpu(bytes);
19*da0073e9SAndroid Build Coastguard Worker   } catch (c10::Error&) {
20*da0073e9SAndroid Build Coastguard Worker     // If allocation fails, try freeing cached available blocks.
21*da0073e9SAndroid Build Coastguard Worker     // For now free all available cached blocks.
22*da0073e9SAndroid Build Coastguard Worker     free_cached();
23*da0073e9SAndroid Build Coastguard Worker     // Furthermore to consider: If we ever come here running out of memory
24*da0073e9SAndroid Build Coastguard Worker     // perhaps it is best to disable caching, since this is likely to happen
25*da0073e9SAndroid Build Coastguard Worker     // again.
26*da0073e9SAndroid Build Coastguard Worker     // Try again.
27*da0073e9SAndroid Build Coastguard Worker     ptr = c10::alloc_cpu(bytes);
28*da0073e9SAndroid Build Coastguard Worker   }
29*da0073e9SAndroid Build Coastguard Worker   allocation_map_[ptr] = bytes;
30*da0073e9SAndroid Build Coastguard Worker   return ptr;
31*da0073e9SAndroid Build Coastguard Worker }
32*da0073e9SAndroid Build Coastguard Worker 
allocate(const size_t bytes)33*da0073e9SAndroid Build Coastguard Worker void* CPUCachingAllocator::allocate(const size_t bytes) {
34*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(mutex_);
35*da0073e9SAndroid Build Coastguard Worker   const auto& it = available_map_.find(bytes);
36*da0073e9SAndroid Build Coastguard Worker   if (it == available_map_.end() || it->second.empty()) {
37*da0073e9SAndroid Build Coastguard Worker     return allocate_and_cache(bytes);
38*da0073e9SAndroid Build Coastguard Worker   }
39*da0073e9SAndroid Build Coastguard Worker   return it->second.pop_back_val();
40*da0073e9SAndroid Build Coastguard Worker }
41*da0073e9SAndroid Build Coastguard Worker 
free(void * ptr)42*da0073e9SAndroid Build Coastguard Worker void CPUCachingAllocator::free(void* ptr) {
43*da0073e9SAndroid Build Coastguard Worker   // NB: since we are not really freeing the memory
44*da0073e9SAndroid Build Coastguard Worker   // the cases such as quantization code freeing original weights
45*da0073e9SAndroid Build Coastguard Worker   // on mobile, will not quite work, as we likely will hold
46*da0073e9SAndroid Build Coastguard Worker   // onto that memory.
47*da0073e9SAndroid Build Coastguard Worker   // NB: We can also enable max memory cached for better memory
48*da0073e9SAndroid Build Coastguard Worker   // management such that free will actually free the memory if
49*da0073e9SAndroid Build Coastguard Worker   // we are nearing or above the watermark.
50*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(mutex_);
51*da0073e9SAndroid Build Coastguard Worker   // If this allocation was done before caching allocator was enabled
52*da0073e9SAndroid Build Coastguard Worker   // then free regularly
53*da0073e9SAndroid Build Coastguard Worker   const auto& it = allocation_map_.find(ptr);
54*da0073e9SAndroid Build Coastguard Worker   if (it == allocation_map_.end()) {
55*da0073e9SAndroid Build Coastguard Worker     c10::free_cpu(ptr);
56*da0073e9SAndroid Build Coastguard Worker     return;
57*da0073e9SAndroid Build Coastguard Worker   }
58*da0073e9SAndroid Build Coastguard Worker   const size_t alloc_size = it->second;
59*da0073e9SAndroid Build Coastguard Worker   available_map_[alloc_size].push_back(ptr);
60*da0073e9SAndroid Build Coastguard Worker }
61*da0073e9SAndroid Build Coastguard Worker 
record_free(void * ptr)62*da0073e9SAndroid Build Coastguard Worker void CPUCachingAllocator::record_free(void* ptr) {
63*da0073e9SAndroid Build Coastguard Worker   // This function captures the case when the allocated memory
64*da0073e9SAndroid Build Coastguard Worker   // is being freed outside the scope of this allocator.
65*da0073e9SAndroid Build Coastguard Worker   // At the moment only way to capture this is to have the allocator,
66*da0073e9SAndroid Build Coastguard Worker   // that uses this CachingAllocator as the backing allocator,
67*da0073e9SAndroid Build Coastguard Worker   // call this function explicitly upon freeing memory while
68*da0073e9SAndroid Build Coastguard Worker   // outside the scope of caching allocator.
69*da0073e9SAndroid Build Coastguard Worker   // If the memory is freed in some other way, then we will likely
70*da0073e9SAndroid Build Coastguard Worker   // have undefined behavior or page fault. But this can be
71*da0073e9SAndroid Build Coastguard Worker   // the case without caching allocator as well.
72*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(mutex_);
73*da0073e9SAndroid Build Coastguard Worker   const auto& it = allocation_map_.find(ptr);
74*da0073e9SAndroid Build Coastguard Worker   if (it != allocation_map_.end()) {
75*da0073e9SAndroid Build Coastguard Worker     allocation_map_.erase(it);
76*da0073e9SAndroid Build Coastguard Worker   }
77*da0073e9SAndroid Build Coastguard Worker }
78*da0073e9SAndroid Build Coastguard Worker 
free_cached()79*da0073e9SAndroid Build Coastguard Worker void CPUCachingAllocator::free_cached() {
80*da0073e9SAndroid Build Coastguard Worker   for (const auto& it : available_map_) {
81*da0073e9SAndroid Build Coastguard Worker     for (const auto ptr : it.second) {
82*da0073e9SAndroid Build Coastguard Worker       c10::free_cpu(ptr);
83*da0073e9SAndroid Build Coastguard Worker       // When cached memory is return to OS, it must be removed
84*da0073e9SAndroid Build Coastguard Worker       // from allocation_map.
85*da0073e9SAndroid Build Coastguard Worker       allocation_map_.erase(ptr);
86*da0073e9SAndroid Build Coastguard Worker     }
87*da0073e9SAndroid Build Coastguard Worker   }
88*da0073e9SAndroid Build Coastguard Worker   available_map_.clear();
89*da0073e9SAndroid Build Coastguard Worker }
90*da0073e9SAndroid Build Coastguard Worker 
~CPUCachingAllocator()91*da0073e9SAndroid Build Coastguard Worker CPUCachingAllocator::~CPUCachingAllocator() {
92*da0073e9SAndroid Build Coastguard Worker   free_cached();
93*da0073e9SAndroid Build Coastguard Worker }
94*da0073e9SAndroid Build Coastguard Worker 
GetThreadLocalCachingAllocator()95*da0073e9SAndroid Build Coastguard Worker CPUCachingAllocator* GetThreadLocalCachingAllocator() {
96*da0073e9SAndroid Build Coastguard Worker   return caching_allocator_ptr;
97*da0073e9SAndroid Build Coastguard Worker }
98*da0073e9SAndroid Build Coastguard Worker 
WithCPUCachingAllocatorGuard(CPUCachingAllocator * allocator)99*da0073e9SAndroid Build Coastguard Worker WithCPUCachingAllocatorGuard::WithCPUCachingAllocatorGuard(
100*da0073e9SAndroid Build Coastguard Worker     CPUCachingAllocator* allocator)
101*da0073e9SAndroid Build Coastguard Worker     : prev_caching_allocator_ptr_(GetThreadLocalCachingAllocator()) {
102*da0073e9SAndroid Build Coastguard Worker   caching_allocator_ptr = allocator;
103*da0073e9SAndroid Build Coastguard Worker }
104*da0073e9SAndroid Build Coastguard Worker 
~WithCPUCachingAllocatorGuard()105*da0073e9SAndroid Build Coastguard Worker WithCPUCachingAllocatorGuard::~WithCPUCachingAllocatorGuard() {
106*da0073e9SAndroid Build Coastguard Worker   caching_allocator_ptr = prev_caching_allocator_ptr_;
107*da0073e9SAndroid Build Coastguard Worker }
108*da0073e9SAndroid Build Coastguard Worker 
109*da0073e9SAndroid Build Coastguard Worker } // namespace c10
110