xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDAAllocatorConfig.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAMacros.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <atomic>
7*da0073e9SAndroid Build Coastguard Worker #include <cstddef>
8*da0073e9SAndroid Build Coastguard Worker #include <cstdlib>
9*da0073e9SAndroid Build Coastguard Worker #include <mutex>
10*da0073e9SAndroid Build Coastguard Worker #include <string>
11*da0073e9SAndroid Build Coastguard Worker #include <vector>
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda::CUDACachingAllocator {
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker // Environment config parser
16*da0073e9SAndroid Build Coastguard Worker class C10_CUDA_API CUDAAllocatorConfig {
17*da0073e9SAndroid Build Coastguard Worker  public:
max_split_size()18*da0073e9SAndroid Build Coastguard Worker   static size_t max_split_size() {
19*da0073e9SAndroid Build Coastguard Worker     return instance().m_max_split_size;
20*da0073e9SAndroid Build Coastguard Worker   }
garbage_collection_threshold()21*da0073e9SAndroid Build Coastguard Worker   static double garbage_collection_threshold() {
22*da0073e9SAndroid Build Coastguard Worker     return instance().m_garbage_collection_threshold;
23*da0073e9SAndroid Build Coastguard Worker   }
24*da0073e9SAndroid Build Coastguard Worker 
expandable_segments()25*da0073e9SAndroid Build Coastguard Worker   static bool expandable_segments() {
26*da0073e9SAndroid Build Coastguard Worker #ifndef PYTORCH_C10_DRIVER_API_SUPPORTED
27*da0073e9SAndroid Build Coastguard Worker     if (instance().m_expandable_segments) {
28*da0073e9SAndroid Build Coastguard Worker       TORCH_WARN_ONCE("expandable_segments not supported on this platform")
29*da0073e9SAndroid Build Coastguard Worker     }
30*da0073e9SAndroid Build Coastguard Worker     return false;
31*da0073e9SAndroid Build Coastguard Worker #else
32*da0073e9SAndroid Build Coastguard Worker     return instance().m_expandable_segments;
33*da0073e9SAndroid Build Coastguard Worker #endif
34*da0073e9SAndroid Build Coastguard Worker   }
35*da0073e9SAndroid Build Coastguard Worker 
release_lock_on_cudamalloc()36*da0073e9SAndroid Build Coastguard Worker   static bool release_lock_on_cudamalloc() {
37*da0073e9SAndroid Build Coastguard Worker     return instance().m_release_lock_on_cudamalloc;
38*da0073e9SAndroid Build Coastguard Worker   }
39*da0073e9SAndroid Build Coastguard Worker 
40*da0073e9SAndroid Build Coastguard Worker   /** Pinned memory allocator settings */
pinned_use_cuda_host_register()41*da0073e9SAndroid Build Coastguard Worker   static bool pinned_use_cuda_host_register() {
42*da0073e9SAndroid Build Coastguard Worker     return instance().m_pinned_use_cuda_host_register;
43*da0073e9SAndroid Build Coastguard Worker   }
44*da0073e9SAndroid Build Coastguard Worker 
pinned_num_register_threads()45*da0073e9SAndroid Build Coastguard Worker   static size_t pinned_num_register_threads() {
46*da0073e9SAndroid Build Coastguard Worker     return instance().m_pinned_num_register_threads;
47*da0073e9SAndroid Build Coastguard Worker   }
48*da0073e9SAndroid Build Coastguard Worker 
pinned_max_register_threads()49*da0073e9SAndroid Build Coastguard Worker   static size_t pinned_max_register_threads() {
50*da0073e9SAndroid Build Coastguard Worker     // Based on the benchmark results, we see better allocation performance
51*da0073e9SAndroid Build Coastguard Worker     // with 8 threads. However on future systems, we may need more threads
52*da0073e9SAndroid Build Coastguard Worker     // and limiting this to 128 threads.
53*da0073e9SAndroid Build Coastguard Worker     return 128;
54*da0073e9SAndroid Build Coastguard Worker   }
55*da0073e9SAndroid Build Coastguard Worker 
56*da0073e9SAndroid Build Coastguard Worker   // This is used to round-up allocation size to nearest power of 2 divisions.
57*da0073e9SAndroid Build Coastguard Worker   // More description below in function roundup_power2_next_division
58*da0073e9SAndroid Build Coastguard Worker   // As ane example, if we want 4 divisions between 2's power, this can be done
59*da0073e9SAndroid Build Coastguard Worker   // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
60*da0073e9SAndroid Build Coastguard Worker   static size_t roundup_power2_divisions(size_t size);
61*da0073e9SAndroid Build Coastguard Worker 
roundup_power2_divisions()62*da0073e9SAndroid Build Coastguard Worker   static std::vector<size_t> roundup_power2_divisions() {
63*da0073e9SAndroid Build Coastguard Worker     return instance().m_roundup_power2_divisions;
64*da0073e9SAndroid Build Coastguard Worker   }
65*da0073e9SAndroid Build Coastguard Worker 
last_allocator_settings()66*da0073e9SAndroid Build Coastguard Worker   static std::string last_allocator_settings() {
67*da0073e9SAndroid Build Coastguard Worker     std::lock_guard<std::mutex> lock(
68*da0073e9SAndroid Build Coastguard Worker         instance().m_last_allocator_settings_mutex);
69*da0073e9SAndroid Build Coastguard Worker     return instance().m_last_allocator_settings;
70*da0073e9SAndroid Build Coastguard Worker   }
71*da0073e9SAndroid Build Coastguard Worker 
instance()72*da0073e9SAndroid Build Coastguard Worker   static CUDAAllocatorConfig& instance() {
73*da0073e9SAndroid Build Coastguard Worker     static CUDAAllocatorConfig* s_instance = ([]() {
74*da0073e9SAndroid Build Coastguard Worker       auto inst = new CUDAAllocatorConfig();
75*da0073e9SAndroid Build Coastguard Worker       const char* env = getenv("PYTORCH_CUDA_ALLOC_CONF");
76*da0073e9SAndroid Build Coastguard Worker       inst->parseArgs(env);
77*da0073e9SAndroid Build Coastguard Worker       return inst;
78*da0073e9SAndroid Build Coastguard Worker     })();
79*da0073e9SAndroid Build Coastguard Worker     return *s_instance;
80*da0073e9SAndroid Build Coastguard Worker   }
81*da0073e9SAndroid Build Coastguard Worker 
82*da0073e9SAndroid Build Coastguard Worker   void parseArgs(const char* env);
83*da0073e9SAndroid Build Coastguard Worker 
84*da0073e9SAndroid Build Coastguard Worker  private:
85*da0073e9SAndroid Build Coastguard Worker   CUDAAllocatorConfig();
86*da0073e9SAndroid Build Coastguard Worker 
87*da0073e9SAndroid Build Coastguard Worker   static void lexArgs(const char* env, std::vector<std::string>& config);
88*da0073e9SAndroid Build Coastguard Worker   static void consumeToken(
89*da0073e9SAndroid Build Coastguard Worker       const std::vector<std::string>& config,
90*da0073e9SAndroid Build Coastguard Worker       size_t i,
91*da0073e9SAndroid Build Coastguard Worker       const char c);
92*da0073e9SAndroid Build Coastguard Worker   size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i);
93*da0073e9SAndroid Build Coastguard Worker   size_t parseGarbageCollectionThreshold(
94*da0073e9SAndroid Build Coastguard Worker       const std::vector<std::string>& config,
95*da0073e9SAndroid Build Coastguard Worker       size_t i);
96*da0073e9SAndroid Build Coastguard Worker   size_t parseRoundUpPower2Divisions(
97*da0073e9SAndroid Build Coastguard Worker       const std::vector<std::string>& config,
98*da0073e9SAndroid Build Coastguard Worker       size_t i);
99*da0073e9SAndroid Build Coastguard Worker   size_t parseAllocatorConfig(
100*da0073e9SAndroid Build Coastguard Worker       const std::vector<std::string>& config,
101*da0073e9SAndroid Build Coastguard Worker       size_t i,
102*da0073e9SAndroid Build Coastguard Worker       bool& used_cudaMallocAsync);
103*da0073e9SAndroid Build Coastguard Worker   size_t parsePinnedUseCudaHostRegister(
104*da0073e9SAndroid Build Coastguard Worker       const std::vector<std::string>& config,
105*da0073e9SAndroid Build Coastguard Worker       size_t i);
106*da0073e9SAndroid Build Coastguard Worker   size_t parsePinnedNumRegisterThreads(
107*da0073e9SAndroid Build Coastguard Worker       const std::vector<std::string>& config,
108*da0073e9SAndroid Build Coastguard Worker       size_t i);
109*da0073e9SAndroid Build Coastguard Worker 
110*da0073e9SAndroid Build Coastguard Worker   std::atomic<size_t> m_max_split_size;
111*da0073e9SAndroid Build Coastguard Worker   std::vector<size_t> m_roundup_power2_divisions;
112*da0073e9SAndroid Build Coastguard Worker   std::atomic<double> m_garbage_collection_threshold;
113*da0073e9SAndroid Build Coastguard Worker   std::atomic<size_t> m_pinned_num_register_threads;
114*da0073e9SAndroid Build Coastguard Worker   std::atomic<bool> m_expandable_segments;
115*da0073e9SAndroid Build Coastguard Worker   std::atomic<bool> m_release_lock_on_cudamalloc;
116*da0073e9SAndroid Build Coastguard Worker   std::atomic<bool> m_pinned_use_cuda_host_register;
117*da0073e9SAndroid Build Coastguard Worker   std::string m_last_allocator_settings;
118*da0073e9SAndroid Build Coastguard Worker   std::mutex m_last_allocator_settings_mutex;
119*da0073e9SAndroid Build Coastguard Worker };
120*da0073e9SAndroid Build Coastguard Worker 
121*da0073e9SAndroid Build Coastguard Worker // General caching allocator utilities
122*da0073e9SAndroid Build Coastguard Worker C10_CUDA_API void setAllocatorSettings(const std::string& env);
123*da0073e9SAndroid Build Coastguard Worker 
124*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda::CUDACachingAllocator
125