1*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/alloc_cpu.h> 2*da0073e9SAndroid Build Coastguard Worker #include <c10/mobile/CPUCachingAllocator.h> 3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h> 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Worker namespace c10 { 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker namespace { 8*da0073e9SAndroid Build Coastguard Worker thread_local CPUCachingAllocator* caching_allocator_ptr{nullptr}; 9*da0073e9SAndroid Build Coastguard Worker } // namespace 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker std::mutex CPUCachingAllocator::mutex_; 12*da0073e9SAndroid Build Coastguard Worker ska::flat_hash_map<void*, size_t> CPUCachingAllocator::allocation_map_; 13*da0073e9SAndroid Build Coastguard Worker allocate_and_cache(const size_t bytes)14*da0073e9SAndroid Build Coastguard Workerinline void* CPUCachingAllocator::allocate_and_cache(const size_t bytes) { 15*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-init-variables) 16*da0073e9SAndroid Build Coastguard Worker void* ptr; 17*da0073e9SAndroid Build Coastguard Worker try { 18*da0073e9SAndroid Build Coastguard Worker ptr = c10::alloc_cpu(bytes); 19*da0073e9SAndroid Build Coastguard Worker } catch (c10::Error&) { 20*da0073e9SAndroid Build Coastguard Worker // If allocation fails, try freeing cached available blocks. 21*da0073e9SAndroid Build Coastguard Worker // For now free all available cached blocks. 22*da0073e9SAndroid Build Coastguard Worker free_cached(); 23*da0073e9SAndroid Build Coastguard Worker // Furthermore to consider: If we ever come here running out of memory 24*da0073e9SAndroid Build Coastguard Worker // perhaps it is best to disable caching, since this is likely to happen 25*da0073e9SAndroid Build Coastguard Worker // again. 26*da0073e9SAndroid Build Coastguard Worker // Try again. 27*da0073e9SAndroid Build Coastguard Worker ptr = c10::alloc_cpu(bytes); 28*da0073e9SAndroid Build Coastguard Worker } 29*da0073e9SAndroid Build Coastguard Worker allocation_map_[ptr] = bytes; 30*da0073e9SAndroid Build Coastguard Worker return ptr; 31*da0073e9SAndroid Build Coastguard Worker } 32*da0073e9SAndroid Build Coastguard Worker allocate(const size_t bytes)33*da0073e9SAndroid Build Coastguard Workervoid* CPUCachingAllocator::allocate(const size_t bytes) { 34*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> guard(mutex_); 35*da0073e9SAndroid Build Coastguard Worker const auto& it = available_map_.find(bytes); 36*da0073e9SAndroid Build Coastguard Worker if (it == available_map_.end() || it->second.empty()) { 37*da0073e9SAndroid Build Coastguard Worker return allocate_and_cache(bytes); 38*da0073e9SAndroid Build Coastguard Worker } 39*da0073e9SAndroid Build Coastguard Worker return it->second.pop_back_val(); 40*da0073e9SAndroid Build Coastguard Worker } 41*da0073e9SAndroid Build Coastguard Worker free(void * ptr)42*da0073e9SAndroid Build Coastguard Workervoid CPUCachingAllocator::free(void* ptr) { 43*da0073e9SAndroid Build Coastguard Worker // NB: since we are not really freeing the memory 44*da0073e9SAndroid Build Coastguard Worker // the cases such as quantization code freeing original weights 45*da0073e9SAndroid Build Coastguard Worker // on mobile, will not quite work, as we likely will hold 46*da0073e9SAndroid Build Coastguard Worker // onto that memory. 47*da0073e9SAndroid Build Coastguard Worker // NB: We can also enable max memory cached for better memory 48*da0073e9SAndroid Build Coastguard Worker // management such that free will actually free the memory if 49*da0073e9SAndroid Build Coastguard Worker // we are nearing or above the watermark. 50*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> guard(mutex_); 51*da0073e9SAndroid Build Coastguard Worker // If this allocation was done before caching allocator was enabled 52*da0073e9SAndroid Build Coastguard Worker // then free regularly 53*da0073e9SAndroid Build Coastguard Worker const auto& it = allocation_map_.find(ptr); 54*da0073e9SAndroid Build Coastguard Worker if (it == allocation_map_.end()) { 55*da0073e9SAndroid Build Coastguard Worker c10::free_cpu(ptr); 56*da0073e9SAndroid Build Coastguard Worker return; 57*da0073e9SAndroid Build Coastguard Worker } 58*da0073e9SAndroid Build Coastguard Worker const size_t alloc_size = it->second; 59*da0073e9SAndroid Build Coastguard Worker available_map_[alloc_size].push_back(ptr); 60*da0073e9SAndroid Build Coastguard Worker } 61*da0073e9SAndroid Build Coastguard Worker record_free(void * ptr)62*da0073e9SAndroid Build Coastguard Workervoid CPUCachingAllocator::record_free(void* ptr) { 63*da0073e9SAndroid Build Coastguard Worker // This function captures the case when the allocated memory 64*da0073e9SAndroid Build Coastguard Worker // is being freed outside the scope of this allocator. 65*da0073e9SAndroid Build Coastguard Worker // At the moment only way to capture this is to have the allocator, 66*da0073e9SAndroid Build Coastguard Worker // that uses this CachingAllocator as the backing allocator, 67*da0073e9SAndroid Build Coastguard Worker // call this function explicitly upon freeing memory while 68*da0073e9SAndroid Build Coastguard Worker // outside the scope of caching allocator. 69*da0073e9SAndroid Build Coastguard Worker // If the memory is freed in some other way, then we will likely 70*da0073e9SAndroid Build Coastguard Worker // have undefined behavior or page fault. But this can be 71*da0073e9SAndroid Build Coastguard Worker // the case without caching allocator as well. 72*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> guard(mutex_); 73*da0073e9SAndroid Build Coastguard Worker const auto& it = allocation_map_.find(ptr); 74*da0073e9SAndroid Build Coastguard Worker if (it != allocation_map_.end()) { 75*da0073e9SAndroid Build Coastguard Worker allocation_map_.erase(it); 76*da0073e9SAndroid Build Coastguard Worker } 77*da0073e9SAndroid Build Coastguard Worker } 78*da0073e9SAndroid Build Coastguard Worker free_cached()79*da0073e9SAndroid Build Coastguard Workervoid CPUCachingAllocator::free_cached() { 80*da0073e9SAndroid Build Coastguard Worker for (const auto& it : available_map_) { 81*da0073e9SAndroid Build Coastguard Worker for (const auto ptr : it.second) { 82*da0073e9SAndroid Build Coastguard Worker c10::free_cpu(ptr); 83*da0073e9SAndroid Build Coastguard Worker // When cached memory is return to OS, it must be removed 84*da0073e9SAndroid Build Coastguard Worker // from allocation_map. 85*da0073e9SAndroid Build Coastguard Worker allocation_map_.erase(ptr); 86*da0073e9SAndroid Build Coastguard Worker } 87*da0073e9SAndroid Build Coastguard Worker } 88*da0073e9SAndroid Build Coastguard Worker available_map_.clear(); 89*da0073e9SAndroid Build Coastguard Worker } 90*da0073e9SAndroid Build Coastguard Worker ~CPUCachingAllocator()91*da0073e9SAndroid Build Coastguard WorkerCPUCachingAllocator::~CPUCachingAllocator() { 92*da0073e9SAndroid Build Coastguard Worker free_cached(); 93*da0073e9SAndroid Build Coastguard Worker } 94*da0073e9SAndroid Build Coastguard Worker GetThreadLocalCachingAllocator()95*da0073e9SAndroid Build Coastguard WorkerCPUCachingAllocator* GetThreadLocalCachingAllocator() { 96*da0073e9SAndroid Build Coastguard Worker return caching_allocator_ptr; 97*da0073e9SAndroid Build Coastguard Worker } 98*da0073e9SAndroid Build Coastguard Worker WithCPUCachingAllocatorGuard(CPUCachingAllocator * allocator)99*da0073e9SAndroid Build Coastguard WorkerWithCPUCachingAllocatorGuard::WithCPUCachingAllocatorGuard( 100*da0073e9SAndroid Build Coastguard Worker CPUCachingAllocator* allocator) 101*da0073e9SAndroid Build Coastguard Worker : prev_caching_allocator_ptr_(GetThreadLocalCachingAllocator()) { 102*da0073e9SAndroid Build Coastguard Worker caching_allocator_ptr = allocator; 103*da0073e9SAndroid Build Coastguard Worker } 104*da0073e9SAndroid Build Coastguard Worker ~WithCPUCachingAllocatorGuard()105*da0073e9SAndroid Build Coastguard WorkerWithCPUCachingAllocatorGuard::~WithCPUCachingAllocatorGuard() { 106*da0073e9SAndroid Build Coastguard Worker caching_allocator_ptr = prev_caching_allocator_ptr_; 107*da0073e9SAndroid Build Coastguard Worker } 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker } // namespace c10 110