#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/llvmMathExtras.h>

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif

namespace c10::cuda::CUDACachingAllocator {

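// Number of power-of-two size intervals (1MB through 64GB) for which a
// roundup division count can be configured.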
constexpr size_t kRoundUpPowerOfTwoIntervals = 16;

CUDAAllocatorConfig::CUDAAllocatorConfig()
    : m_max_split_size(std::numeric_limits<size_t>::max()),
      m_garbage_collection_threshold(0),
      m_pinned_num_register_threads(1),
      m_expandable_segments(false),
      m_release_lock_on_cudamalloc(false),
      m_pinned_use_cuda_host_register(false),
      m_last_allocator_settings("") {
  m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
}

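// Return the configured number of divisions for the power-of-two interval
// containing `size`, clamping the lookup to the 1MB-64GB range.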
size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
  size_t log_size = (63 - llvm::countLeadingZeros(size));

  // Our intervals start at 1MB and end at 64GB
  const size_t interval_start =
      63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
  const size_t interval_end =
      63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
  TORCH_CHECK(
      (interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
      "kRoundUpPowerOfTwoIntervals mismatch");

  int index = static_cast<int>(log_size) - static_cast<int>(interval_start);

  index = std::max(0, index);
  index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
  return instance().m_roundup_power2_divisions[index];
}

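// Split the settings string into tokens: ',', ':', '[' and ']' become
// single-character tokens, whitespace is dropped, and everything else is
// collected into word tokens.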
void CUDAAllocatorConfig::lexArgs(
    const char* env,
    std::vector<std::string>& config) {
  std::vector<char> buf;

  size_t env_length = strlen(env);
  for (size_t i = 0; i < env_length; i++) {
    if (env[i] == ',' || env[i] == ':' || env[i] == '[' || env[i] == ']') {
      if (!buf.empty()) {
        config.emplace_back(buf.begin(), buf.end());
        buf.clear();
      }
      config.emplace_back(1, env[i]);
    } else if (env[i] != ' ') {
      buf.emplace_back(static_cast<char>(env[i]));
    }
  }
  if (!buf.empty()) {
    config.emplace_back(buf.begin(), buf.end());
  }
}

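// Check that token i exists and is exactly the expected single character c.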
void CUDAAllocatorConfig::consumeToken(
    const std::vector<std::string>& config,
    size_t i,
    const char c) {
  TORCH_CHECK(
      i < config.size() && config[i] == std::string(1, c),
      "Error parsing CachingAllocator settings, expected ",
      c,
      "");
}

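// Parse "max_split_size_mb:<value>". The value is clamped to
// [kLargeBuffer/MB, SIZE_MAX/MB] and stored in bytes.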
size_t CUDAAllocatorConfig::parseMaxSplitSize(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  constexpr int mb = 1024 * 1024;
  if (++i < config.size()) {
    size_t val1 = stoi(config[i]);
    TORCH_CHECK(
        val1 > kLargeBuffer / mb,
        "CachingAllocator option max_split_size_mb too small, must be > ",
        kLargeBuffer / mb,
        "");
    val1 = std::max(val1, kLargeBuffer / mb);
    val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
    m_max_split_size = val1 * 1024 * 1024;
  } else {
    TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
  }
  return i;
}

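// Parse "garbage_collection_threshold:<value>"; the value must lie strictly
// between 0.0 and 1.0.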
size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    double val1 = stod(config[i]);
    TORCH_CHECK(
        val1 > 0,
        "garbage_collection_threshold too small, expected a value between 0.0 and 1.0",
        "");
    TORCH_CHECK(
        val1 < 1.0,
        "garbage_collection_threshold too big, expected a value between 0.0 and 1.0",
        "");
    m_garbage_collection_threshold = val1;
  } else {
    TORCH_CHECK(
        false, "Error, expecting garbage_collection_threshold value", "");
  }
  return i;
}

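// Parse "roundup_power2_divisions". Two forms are accepted: a single
// power-of-two value applied to every interval (legacy form), or a list such
// as "[256:8,512:4,>:2]" where each power-of-two key selects the interval it
// configures and ">" applies its value to all remaining larger intervals.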
size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  bool first_value = true;

  if (++i < config.size()) {
    if (std::string_view(config[i]) == "[") {
      size_t last_index = 0;
      while (++i < config.size() && std::string_view(config[i]) != "]") {
        const std::string& val1 = config[i];
        size_t val2 = 0;

        consumeToken(config, ++i, ':');
        if (++i < config.size()) {
          val2 = stoi(config[i]);
        } else {
          TORCH_CHECK(
              false, "Error parsing roundup_power2_divisions value", "");
        }
        TORCH_CHECK(
            val2 == 0 || llvm::isPowerOf2_64(val2),
            "For roundups, the divisions have to be a power of 2 or 0 to disable roundup ",
            "");

        if (std::string_view(val1) == ">") {
          std::fill(
              std::next(
                  m_roundup_power2_divisions.begin(),
                  static_cast<std::vector<unsigned long>::difference_type>(
                      last_index)),
              m_roundup_power2_divisions.end(),
              val2);
        } else {
          size_t val1_long = stoul(val1);
          TORCH_CHECK(
              llvm::isPowerOf2_64(val1_long),
              "For roundups, the intervals have to be a power of 2 ",
              "");

          size_t index = 63 - llvm::countLeadingZeros(val1_long);
          index = std::max((size_t)0, index);
          index = std::min(index, m_roundup_power2_divisions.size() - 1);

          if (first_value) {
            std::fill(
                m_roundup_power2_divisions.begin(),
                std::next(
                    m_roundup_power2_divisions.begin(),
                    static_cast<std::vector<unsigned long>::difference_type>(
                        index)),
                val2);
            first_value = false;
          }
          if (index < m_roundup_power2_divisions.size()) {
            m_roundup_power2_divisions[index] = val2;
          }
          last_index = index;
        }

        if (i + 1 < config.size() && std::string_view(config[i + 1]) != "]") {
          consumeToken(config, ++i, ',');
        }
      }
    } else { // Keep this for backwards compatibility
      size_t val1 = stoi(config[i]);
      TORCH_CHECK(
          llvm::isPowerOf2_64(val1),
          "For roundups, the divisions have to be a power of 2 ",
          "");
      std::fill(
          m_roundup_power2_divisions.begin(),
          m_roundup_power2_divisions.end(),
          val1);
    }
  } else {
    TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
  }
  return i;
}

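// Parse "backend:native" or "backend:cudaMallocAsync". For cudaMallocAsync
// this verifies that both the build-time and runtime CUDA versions are at
// least 11.4, and that the backend matches the one selected at load time.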
size_t CUDAAllocatorConfig::parseAllocatorConfig(
    const std::vector<std::string>& config,
    size_t i,
    bool& used_cudaMallocAsync) {
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    TORCH_CHECK(
        ((config[i] == "native") || (config[i] == "cudaMallocAsync")),
        "Unknown allocator backend, "
        "options are native and cudaMallocAsync");
    used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
#ifndef USE_ROCM
    // HIP supports hipMallocAsync and does not need to check versions
    if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
      int version = 0;
      C10_CUDA_CHECK(cudaDriverGetVersion(&version));
      TORCH_CHECK(
          version >= 11040,
          "backend:cudaMallocAsync requires CUDA runtime "
          "11.4 or newer, but cudaDriverGetVersion returned ",
          version);
#else
      TORCH_CHECK(
          false,
          "backend:cudaMallocAsync requires PyTorch to be built with "
          "CUDA 11.4 or newer, but CUDA_VERSION is ",
          CUDA_VERSION);
#endif
    }
#endif
    TORCH_INTERNAL_ASSERT(
        config[i] == get()->name(),
        "Allocator backend parsed at runtime != "
        "allocator backend parsed at load time");
  } else {
    TORCH_CHECK(false, "Error parsing backend value", "");
  }
  return i;
}

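// Parse the full allocator settings string (e.g. the value of
// PYTORCH_CUDA_ALLOC_CONF, such as
// "max_split_size_mb:128,garbage_collection_threshold:0.8"). Resets all
// options to their defaults first, then dispatches each "option:value"
// pair to the corresponding parser.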
void CUDAAllocatorConfig::parseArgs(const char* env) {
  // If empty, set the default values
  m_max_split_size = std::numeric_limits<size_t>::max();
  m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
  m_garbage_collection_threshold = 0;
  bool used_cudaMallocAsync = false;
  bool used_native_specific_option = false;

  if (env == nullptr) {
    return;
  }
  {
    std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
    m_last_allocator_settings = env;
  }

  std::vector<std::string> config;
  lexArgs(env, config);

  for (size_t i = 0; i < config.size(); i++) {
    std::string_view config_item_view(config[i]);
    if (config_item_view == "max_split_size_mb") {
      i = parseMaxSplitSize(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "garbage_collection_threshold") {
      i = parseGarbageCollectionThreshold(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "roundup_power2_divisions") {
      i = parseRoundUpPower2Divisions(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "backend") {
      i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
    } else if (config_item_view == "expandable_segments") {
      used_native_specific_option = true;
      consumeToken(config, ++i, ':');
      ++i;
      TORCH_CHECK(
          i < config.size() &&
              (std::string_view(config[i]) == "True" ||
               std::string_view(config[i]) == "False"),
          "Expected a single True/False argument for expandable_segments");
      config_item_view = config[i];
      m_expandable_segments = (config_item_view == "True");
    } else if (
        // ROCm build's hipify step will change "cuda" to "hip", but for ease
        // of use, accept both. We must break up the string to prevent hipify
        // here.
        config_item_view == "release_lock_on_hipmalloc" ||
        config_item_view ==
            "release_lock_on_c"
            "udamalloc") {
      used_native_specific_option = true;
      consumeToken(config, ++i, ':');
      ++i;
      TORCH_CHECK(
          i < config.size() &&
              (std::string_view(config[i]) == "True" ||
               std::string_view(config[i]) == "False"),
          "Expected a single True/False argument for release_lock_on_cudamalloc");
      config_item_view = config[i];
      m_release_lock_on_cudamalloc = (config_item_view == "True");
    } else if (
        // ROCm build's hipify step will change "cuda" to "hip", but for ease
        // of use, accept both. We must break up the string to prevent hipify
        // here.
        config_item_view == "pinned_use_hip_host_register" ||
        config_item_view ==
            "pinned_use_c"
            "uda_host_register") {
      i = parsePinnedUseCudaHostRegister(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "pinned_num_register_threads") {
      i = parsePinnedNumRegisterThreads(config, i);
      used_native_specific_option = true;
    } else {
      TORCH_CHECK(
          false, "Unrecognized CachingAllocator option: ", config_item_view);
    }

    if (i + 1 < config.size()) {
      consumeToken(config, ++i, ',');
    }
  }

  if (used_cudaMallocAsync && used_native_specific_option) {
    TORCH_WARN(
        "backend:cudaMallocAsync ignores max_split_size_mb, "
        "roundup_power2_divisions, and garbage_collection_threshold.");
  }
}

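// Parse "pinned_use_cuda_host_register:True|False" (or the hip-spelled
// variant on ROCm builds).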
size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    TORCH_CHECK(
        (config[i] == "True" || config[i] == "False"),
        "Expected a single True/False argument for pinned_use_cuda_host_register");
    m_pinned_use_cuda_host_register = (config[i] == "True");
  } else {
    TORCH_CHECK(
        false, "Error, expecting pinned_use_cuda_host_register value", "");
  }
  return i;
}

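// Parse "pinned_num_register_threads:<value>". The value must be a power of
// 2 and no larger than pinned_max_register_threads().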
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    size_t val2 = stoi(config[i]);
    TORCH_CHECK(
        llvm::isPowerOf2_64(val2),
        "Number of register threads has to be a power of 2 ",
        "");
    auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
    TORCH_CHECK(
        val2 <= maxThreads,
        "Number of register threads should be less than or equal to " +
            std::to_string(maxThreads),
        "");
    m_pinned_num_register_threads = val2;
  } else {
    TORCH_CHECK(
        false, "Error, expecting pinned_num_register_threads value", "");
  }
  return i;
}

// General caching allocator utilities
void setAllocatorSettings(const std::string& env) {
  CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
}

} // namespace c10::cuda::CUDACachingAllocator