1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAMacros.h> 4*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h> 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Worker #include <atomic> 7*da0073e9SAndroid Build Coastguard Worker #include <cstddef> 8*da0073e9SAndroid Build Coastguard Worker #include <cstdlib> 9*da0073e9SAndroid Build Coastguard Worker #include <mutex> 10*da0073e9SAndroid Build Coastguard Worker #include <string> 11*da0073e9SAndroid Build Coastguard Worker #include <vector> 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda::CUDACachingAllocator { 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker // Environment config parser 16*da0073e9SAndroid Build Coastguard Worker class C10_CUDA_API CUDAAllocatorConfig { 17*da0073e9SAndroid Build Coastguard Worker public: max_split_size()18*da0073e9SAndroid Build Coastguard Worker static size_t max_split_size() { 19*da0073e9SAndroid Build Coastguard Worker return instance().m_max_split_size; 20*da0073e9SAndroid Build Coastguard Worker } garbage_collection_threshold()21*da0073e9SAndroid Build Coastguard Worker static double garbage_collection_threshold() { 22*da0073e9SAndroid Build Coastguard Worker return instance().m_garbage_collection_threshold; 23*da0073e9SAndroid Build Coastguard Worker } 24*da0073e9SAndroid Build Coastguard Worker expandable_segments()25*da0073e9SAndroid Build Coastguard Worker static bool expandable_segments() { 26*da0073e9SAndroid Build Coastguard Worker #ifndef PYTORCH_C10_DRIVER_API_SUPPORTED 27*da0073e9SAndroid Build Coastguard Worker if (instance().m_expandable_segments) { 28*da0073e9SAndroid Build Coastguard Worker TORCH_WARN_ONCE("expandable_segments not supported on this platform") 29*da0073e9SAndroid Build Coastguard Worker } 30*da0073e9SAndroid Build Coastguard Worker return false; 31*da0073e9SAndroid Build Coastguard Worker #else 32*da0073e9SAndroid Build Coastguard Worker return instance().m_expandable_segments; 33*da0073e9SAndroid Build Coastguard Worker #endif 34*da0073e9SAndroid Build Coastguard Worker } 35*da0073e9SAndroid Build Coastguard Worker release_lock_on_cudamalloc()36*da0073e9SAndroid Build Coastguard Worker static bool release_lock_on_cudamalloc() { 37*da0073e9SAndroid Build Coastguard Worker return instance().m_release_lock_on_cudamalloc; 38*da0073e9SAndroid Build Coastguard Worker } 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker /** Pinned memory allocator settings */ pinned_use_cuda_host_register()41*da0073e9SAndroid Build Coastguard Worker static bool pinned_use_cuda_host_register() { 42*da0073e9SAndroid Build Coastguard Worker return instance().m_pinned_use_cuda_host_register; 43*da0073e9SAndroid Build Coastguard Worker } 44*da0073e9SAndroid Build Coastguard Worker pinned_num_register_threads()45*da0073e9SAndroid Build Coastguard Worker static size_t pinned_num_register_threads() { 46*da0073e9SAndroid Build Coastguard Worker return instance().m_pinned_num_register_threads; 47*da0073e9SAndroid Build Coastguard Worker } 48*da0073e9SAndroid Build Coastguard Worker pinned_max_register_threads()49*da0073e9SAndroid Build Coastguard Worker static size_t pinned_max_register_threads() { 50*da0073e9SAndroid Build Coastguard Worker // Based on the benchmark results, we see better allocation performance 51*da0073e9SAndroid Build Coastguard Worker // with 8 threads. However on future systems, we may need more threads 52*da0073e9SAndroid Build Coastguard Worker // and limiting this to 128 threads. 53*da0073e9SAndroid Build Coastguard Worker return 128; 54*da0073e9SAndroid Build Coastguard Worker } 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker // This is used to round-up allocation size to nearest power of 2 divisions. 57*da0073e9SAndroid Build Coastguard Worker // More description below in function roundup_power2_next_division 58*da0073e9SAndroid Build Coastguard Worker // As ane example, if we want 4 divisions between 2's power, this can be done 59*da0073e9SAndroid Build Coastguard Worker // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 60*da0073e9SAndroid Build Coastguard Worker static size_t roundup_power2_divisions(size_t size); 61*da0073e9SAndroid Build Coastguard Worker roundup_power2_divisions()62*da0073e9SAndroid Build Coastguard Worker static std::vector<size_t> roundup_power2_divisions() { 63*da0073e9SAndroid Build Coastguard Worker return instance().m_roundup_power2_divisions; 64*da0073e9SAndroid Build Coastguard Worker } 65*da0073e9SAndroid Build Coastguard Worker last_allocator_settings()66*da0073e9SAndroid Build Coastguard Worker static std::string last_allocator_settings() { 67*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock( 68*da0073e9SAndroid Build Coastguard Worker instance().m_last_allocator_settings_mutex); 69*da0073e9SAndroid Build Coastguard Worker return instance().m_last_allocator_settings; 70*da0073e9SAndroid Build Coastguard Worker } 71*da0073e9SAndroid Build Coastguard Worker instance()72*da0073e9SAndroid Build Coastguard Worker static CUDAAllocatorConfig& instance() { 73*da0073e9SAndroid Build Coastguard Worker static CUDAAllocatorConfig* s_instance = ([]() { 74*da0073e9SAndroid Build Coastguard Worker auto inst = new CUDAAllocatorConfig(); 75*da0073e9SAndroid Build Coastguard Worker const char* env = getenv("PYTORCH_CUDA_ALLOC_CONF"); 76*da0073e9SAndroid Build Coastguard Worker inst->parseArgs(env); 77*da0073e9SAndroid Build Coastguard Worker return inst; 78*da0073e9SAndroid Build Coastguard Worker })(); 79*da0073e9SAndroid Build Coastguard Worker return *s_instance; 80*da0073e9SAndroid Build Coastguard Worker } 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker void parseArgs(const char* env); 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker private: 85*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig(); 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker static void lexArgs(const char* env, std::vector<std::string>& config); 88*da0073e9SAndroid Build Coastguard Worker static void consumeToken( 89*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config, 90*da0073e9SAndroid Build Coastguard Worker size_t i, 91*da0073e9SAndroid Build Coastguard Worker const char c); 92*da0073e9SAndroid Build Coastguard Worker size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i); 93*da0073e9SAndroid Build Coastguard Worker size_t parseGarbageCollectionThreshold( 94*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config, 95*da0073e9SAndroid Build Coastguard Worker size_t i); 96*da0073e9SAndroid Build Coastguard Worker size_t parseRoundUpPower2Divisions( 97*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config, 98*da0073e9SAndroid Build Coastguard Worker size_t i); 99*da0073e9SAndroid Build Coastguard Worker size_t parseAllocatorConfig( 100*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config, 101*da0073e9SAndroid Build Coastguard Worker size_t i, 102*da0073e9SAndroid Build Coastguard Worker bool& used_cudaMallocAsync); 103*da0073e9SAndroid Build Coastguard Worker size_t parsePinnedUseCudaHostRegister( 104*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config, 105*da0073e9SAndroid Build Coastguard Worker size_t i); 106*da0073e9SAndroid Build Coastguard Worker size_t parsePinnedNumRegisterThreads( 107*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config, 108*da0073e9SAndroid Build Coastguard Worker size_t i); 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker std::atomic<size_t> m_max_split_size; 111*da0073e9SAndroid Build Coastguard Worker std::vector<size_t> m_roundup_power2_divisions; 112*da0073e9SAndroid Build Coastguard Worker std::atomic<double> m_garbage_collection_threshold; 113*da0073e9SAndroid Build Coastguard Worker std::atomic<size_t> m_pinned_num_register_threads; 114*da0073e9SAndroid Build Coastguard Worker std::atomic<bool> m_expandable_segments; 115*da0073e9SAndroid Build Coastguard Worker std::atomic<bool> m_release_lock_on_cudamalloc; 116*da0073e9SAndroid Build Coastguard Worker std::atomic<bool> m_pinned_use_cuda_host_register; 117*da0073e9SAndroid Build Coastguard Worker std::string m_last_allocator_settings; 118*da0073e9SAndroid Build Coastguard Worker std::mutex m_last_allocator_settings_mutex; 119*da0073e9SAndroid Build Coastguard Worker }; 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker // General caching allocator utilities 122*da0073e9SAndroid Build Coastguard Worker C10_CUDA_API void setAllocatorSettings(const std::string& env); 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda::CUDACachingAllocator 125