#include #include #include #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) #include #endif namespace c10::cuda::CUDACachingAllocator { constexpr size_t kRoundUpPowerOfTwoIntervals = 16; CUDAAllocatorConfig::CUDAAllocatorConfig() : m_max_split_size(std::numeric_limits::max()), m_garbage_collection_threshold(0), m_pinned_num_register_threads(1), m_expandable_segments(false), m_release_lock_on_cudamalloc(false), m_pinned_use_cuda_host_register(false), m_last_allocator_settings("") { m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); } size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) { size_t log_size = (63 - llvm::countLeadingZeros(size)); // Our intervals start at 1MB and end at 64GB const size_t interval_start = 63 - llvm::countLeadingZeros(static_cast(1048576)); const size_t interval_end = 63 - llvm::countLeadingZeros(static_cast(68719476736)); TORCH_CHECK( (interval_end - interval_start == kRoundUpPowerOfTwoIntervals), "kRoundUpPowerOfTwoIntervals mismatch"); int index = static_cast(log_size) - static_cast(interval_start); index = std::max(0, index); index = std::min(index, static_cast(kRoundUpPowerOfTwoIntervals) - 1); return instance().m_roundup_power2_divisions[index]; } void CUDAAllocatorConfig::lexArgs( const char* env, std::vector& config) { std::vector buf; size_t env_length = strlen(env); for (size_t i = 0; i < env_length; i++) { if (env[i] == ',' || env[i] == ':' || env[i] == '[' || env[i] == ']') { if (!buf.empty()) { config.emplace_back(buf.begin(), buf.end()); buf.clear(); } config.emplace_back(1, env[i]); } else if (env[i] != ' ') { buf.emplace_back(static_cast(env[i])); } } if (!buf.empty()) { config.emplace_back(buf.begin(), buf.end()); } } void CUDAAllocatorConfig::consumeToken( const std::vector& config, size_t i, const char c) { TORCH_CHECK( i < config.size() && config[i] == std::string(1, c), "Error parsing CachingAllocator settings, expected ", c, ""); } size_t CUDAAllocatorConfig::parseMaxSplitSize( const std::vector& config, size_t i) { consumeToken(config, ++i, ':'); constexpr int mb = 1024 * 1024; if (++i < config.size()) { size_t val1 = stoi(config[i]); TORCH_CHECK( val1 > kLargeBuffer / mb, "CachingAllocator option max_split_size_mb too small, must be > ", kLargeBuffer / mb, ""); val1 = std::max(val1, kLargeBuffer / mb); val1 = std::min(val1, (std::numeric_limits::max() / mb)); m_max_split_size = val1 * 1024 * 1024; } else { TORCH_CHECK(false, "Error, expecting max_split_size_mb value", ""); } return i; } size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold( const std::vector& config, size_t i) { consumeToken(config, ++i, ':'); if (++i < config.size()) { double val1 = stod(config[i]); TORCH_CHECK( val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", ""); TORCH_CHECK( val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", ""); m_garbage_collection_threshold = val1; } else { TORCH_CHECK( false, "Error, expecting garbage_collection_threshold value", ""); } return i; } size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions( const std::vector& config, size_t i) { consumeToken(config, ++i, ':'); bool first_value = true; if (++i < config.size()) { if (std::string_view(config[i]) == "[") { size_t last_index = 0; while (++i < config.size() && std::string_view(config[i]) != "]") { const std::string& val1 = config[i]; size_t val2 = 0; consumeToken(config, ++i, ':'); if (++i < config.size()) { val2 = stoi(config[i]); } else { TORCH_CHECK( false, "Error parsing roundup_power2_divisions value", ""); } TORCH_CHECK( val2 == 0 || llvm::isPowerOf2_64(val2), "For roundups, the divisons has to be power of 2 or 0 to disable roundup ", ""); if (std::string_view(val1) == ">") { std::fill( std::next( m_roundup_power2_divisions.begin(), static_cast::difference_type>( last_index)), m_roundup_power2_divisions.end(), val2); } else { size_t val1_long = stoul(val1); TORCH_CHECK( llvm::isPowerOf2_64(val1_long), "For roundups, the intervals have to be power of 2 ", ""); size_t index = 63 - llvm::countLeadingZeros(val1_long); index = std::max((size_t)0, index); index = std::min(index, m_roundup_power2_divisions.size() - 1); if (first_value) { std::fill( m_roundup_power2_divisions.begin(), std::next( m_roundup_power2_divisions.begin(), static_cast::difference_type>( index)), val2); first_value = false; } if (index < m_roundup_power2_divisions.size()) { m_roundup_power2_divisions[index] = val2; } last_index = index; } if (std::string_view(config[i + 1]) != "]") { consumeToken(config, ++i, ','); } } } else { // Keep this for backwards compatibility size_t val1 = stoi(config[i]); TORCH_CHECK( llvm::isPowerOf2_64(val1), "For roundups, the divisons has to be power of 2 ", ""); std::fill( m_roundup_power2_divisions.begin(), m_roundup_power2_divisions.end(), val1); } } else { TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", ""); } return i; } size_t CUDAAllocatorConfig::parseAllocatorConfig( const std::vector& config, size_t i, bool& used_cudaMallocAsync) { consumeToken(config, ++i, ':'); if (++i < config.size()) { TORCH_CHECK( ((config[i] == "native") || (config[i] == "cudaMallocAsync")), "Unknown allocator backend, " "options are native and cudaMallocAsync"); used_cudaMallocAsync = (config[i] == "cudaMallocAsync"); #ifndef USE_ROCM // HIP supports hipMallocAsync and does not need to check versions if (used_cudaMallocAsync) { #if CUDA_VERSION >= 11040 int version = 0; C10_CUDA_CHECK(cudaDriverGetVersion(&version)); TORCH_CHECK( version >= 11040, "backend:cudaMallocAsync requires CUDA runtime " "11.4 or newer, but cudaDriverGetVersion returned ", version); #else TORCH_CHECK( false, "backend:cudaMallocAsync requires PyTorch to be built with " "CUDA 11.4 or newer, but CUDA_VERSION is ", CUDA_VERSION); #endif } #endif TORCH_INTERNAL_ASSERT( config[i] == get()->name(), "Allocator backend parsed at runtime != " "allocator backend parsed at load time"); } else { TORCH_CHECK(false, "Error parsing backend value", ""); } return i; } void CUDAAllocatorConfig::parseArgs(const char* env) { // If empty, set the default values m_max_split_size = std::numeric_limits::max(); m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); m_garbage_collection_threshold = 0; bool used_cudaMallocAsync = false; bool used_native_specific_option = false; if (env == nullptr) { return; } { std::lock_guard lock(m_last_allocator_settings_mutex); m_last_allocator_settings = env; } std::vector config; lexArgs(env, config); for (size_t i = 0; i < config.size(); i++) { std::string_view config_item_view(config[i]); if (config_item_view == "max_split_size_mb") { i = parseMaxSplitSize(config, i); used_native_specific_option = true; } else if (config_item_view == "garbage_collection_threshold") { i = parseGarbageCollectionThreshold(config, i); used_native_specific_option = true; } else if (config_item_view == "roundup_power2_divisions") { i = parseRoundUpPower2Divisions(config, i); used_native_specific_option = true; } else if (config_item_view == "backend") { i = parseAllocatorConfig(config, i, used_cudaMallocAsync); } else if (config_item_view == "expandable_segments") { used_native_specific_option = true; consumeToken(config, ++i, ':'); ++i; TORCH_CHECK( i < config.size() && (std::string_view(config[i]) == "True" || std::string_view(config[i]) == "False"), "Expected a single True/False argument for expandable_segments"); config_item_view = config[i]; m_expandable_segments = (config_item_view == "True"); } else if ( // ROCm build's hipify step will change "cuda" to "hip", but for ease of // use, accept both. We must break up the string to prevent hipify here. config_item_view == "release_lock_on_hipmalloc" || config_item_view == "release_lock_on_c" "udamalloc") { used_native_specific_option = true; consumeToken(config, ++i, ':'); ++i; TORCH_CHECK( i < config.size() && (std::string_view(config[i]) == "True" || std::string_view(config[i]) == "False"), "Expected a single True/False argument for release_lock_on_cudamalloc"); config_item_view = config[i]; m_release_lock_on_cudamalloc = (config_item_view == "True"); } else if ( // ROCm build's hipify step will change "cuda" to "hip", but for ease of // use, accept both. We must break up the string to prevent hipify here. config_item_view == "pinned_use_hip_host_register" || config_item_view == "pinned_use_c" "uda_host_register") { i = parsePinnedUseCudaHostRegister(config, i); used_native_specific_option = true; } else if (config_item_view == "pinned_num_register_threads") { i = parsePinnedNumRegisterThreads(config, i); used_native_specific_option = true; } else { TORCH_CHECK( false, "Unrecognized CachingAllocator option: ", config_item_view); } if (i + 1 < config.size()) { consumeToken(config, ++i, ','); } } if (used_cudaMallocAsync && used_native_specific_option) { TORCH_WARN( "backend:cudaMallocAsync ignores max_split_size_mb," "roundup_power2_divisions, and garbage_collect_threshold."); } } size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister( const std::vector& config, size_t i) { consumeToken(config, ++i, ':'); if (++i < config.size()) { TORCH_CHECK( (config[i] == "True" || config[i] == "False"), "Expected a single True/False argument for pinned_use_cuda_host_register"); m_pinned_use_cuda_host_register = (config[i] == "True"); } else { TORCH_CHECK( false, "Error, expecting pinned_use_cuda_host_register value", ""); } return i; } size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( const std::vector& config, size_t i) { consumeToken(config, ++i, ':'); if (++i < config.size()) { size_t val2 = stoi(config[i]); TORCH_CHECK( llvm::isPowerOf2_64(val2), "Number of register threads has to be power of 2 ", ""); auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads(); TORCH_CHECK( val2 <= maxThreads, "Number of register threads should be less than or equal to " + std::to_string(maxThreads), ""); m_pinned_num_register_threads = val2; } else { TORCH_CHECK( false, "Error, expecting pinned_num_register_threads value", ""); } return i; } // General caching allocator utilities void setAllocatorSettings(const std::string& env) { CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str()); } } // namespace c10::cuda::CUDACachingAllocator