xref: /aosp_15_r20/external/pytorch/c10/mobile/CPUCachingAllocator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/impl/alloc_cpu.h>
2 #include <c10/mobile/CPUCachingAllocator.h>
3 #include <c10/util/Exception.h>
4 
5 namespace c10 {
6 
7 namespace {
8 thread_local CPUCachingAllocator* caching_allocator_ptr{nullptr};
9 } // namespace
10 
11 std::mutex CPUCachingAllocator::mutex_;
12 ska::flat_hash_map<void*, size_t> CPUCachingAllocator::allocation_map_;
13 
allocate_and_cache(const size_t bytes)14 inline void* CPUCachingAllocator::allocate_and_cache(const size_t bytes) {
15   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
16   void* ptr;
17   try {
18     ptr = c10::alloc_cpu(bytes);
19   } catch (c10::Error&) {
20     // If allocation fails, try freeing cached available blocks.
21     // For now free all available cached blocks.
22     free_cached();
23     // Furthermore to consider: If we ever come here running out of memory
24     // perhaps it is best to disable caching, since this is likely to happen
25     // again.
26     // Try again.
27     ptr = c10::alloc_cpu(bytes);
28   }
29   allocation_map_[ptr] = bytes;
30   return ptr;
31 }
32 
allocate(const size_t bytes)33 void* CPUCachingAllocator::allocate(const size_t bytes) {
34   std::lock_guard<std::mutex> guard(mutex_);
35   const auto& it = available_map_.find(bytes);
36   if (it == available_map_.end() || it->second.empty()) {
37     return allocate_and_cache(bytes);
38   }
39   return it->second.pop_back_val();
40 }
41 
free(void * ptr)42 void CPUCachingAllocator::free(void* ptr) {
43   // NB: since we are not really freeing the memory
44   // the cases such as quantization code freeing original weights
45   // on mobile, will not quite work, as we likely will hold
46   // onto that memory.
47   // NB: We can also enable max memory cached for better memory
48   // management such that free will actually free the memory if
49   // we are nearing or above the watermark.
50   std::lock_guard<std::mutex> guard(mutex_);
51   // If this allocation was done before caching allocator was enabled
52   // then free regularly
53   const auto& it = allocation_map_.find(ptr);
54   if (it == allocation_map_.end()) {
55     c10::free_cpu(ptr);
56     return;
57   }
58   const size_t alloc_size = it->second;
59   available_map_[alloc_size].push_back(ptr);
60 }
61 
record_free(void * ptr)62 void CPUCachingAllocator::record_free(void* ptr) {
63   // This function captures the case when the allocated memory
64   // is being freed outside the scope of this allocator.
65   // At the moment only way to capture this is to have the allocator,
66   // that uses this CachingAllocator as the backing allocator,
67   // call this function explicitly upon freeing memory while
68   // outside the scope of caching allocator.
69   // If the memory is freed in some other way, then we will likely
70   // have undefined behavior or page fault. But this can be
71   // the case without caching allocator as well.
72   std::lock_guard<std::mutex> guard(mutex_);
73   const auto& it = allocation_map_.find(ptr);
74   if (it != allocation_map_.end()) {
75     allocation_map_.erase(it);
76   }
77 }
78 
free_cached()79 void CPUCachingAllocator::free_cached() {
80   for (const auto& it : available_map_) {
81     for (const auto ptr : it.second) {
82       c10::free_cpu(ptr);
83       // When cached memory is return to OS, it must be removed
84       // from allocation_map.
85       allocation_map_.erase(ptr);
86     }
87   }
88   available_map_.clear();
89 }
90 
~CPUCachingAllocator()91 CPUCachingAllocator::~CPUCachingAllocator() {
92   free_cached();
93 }
94 
GetThreadLocalCachingAllocator()95 CPUCachingAllocator* GetThreadLocalCachingAllocator() {
96   return caching_allocator_ptr;
97 }
98 
WithCPUCachingAllocatorGuard(CPUCachingAllocator * allocator)99 WithCPUCachingAllocatorGuard::WithCPUCachingAllocatorGuard(
100     CPUCachingAllocator* allocator)
101     : prev_caching_allocator_ptr_(GetThreadLocalCachingAllocator()) {
102   caching_allocator_ptr = allocator;
103 }
104 
~WithCPUCachingAllocatorGuard()105 WithCPUCachingAllocatorGuard::~WithCPUCachingAllocatorGuard() {
106   caching_allocator_ptr = prev_caching_allocator_ptr_;
107 }
108 
109 } // namespace c10
110