/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOCASYNC_ALLOCATOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOCASYNC_ALLOCATOR_H_

#include <atomic>
#include <memory>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif  // GOOGLE_CUDA

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

#if GOOGLE_CUDA
#define TF_CUDA_MALLOC_ASYNC_SUPPORTED CUDA_VERSION >= 11020
#endif

// An allocator that wraps cudaMallocAsync. It has fewer fragmentation
// issues than the BFC memory allocator. The compute-sanitizer tool
// helps to detect OOB memory errors when using cudaMallocAsync. Use
// the environment variable `TF_GPU_ALLOCATOR=cuda_malloc_async` to
// enable it.
//
// It needs CUDA 11.2+. When using a container, only the container
// driver needs to be 11.2. It has a WAR against a driver bug in
// multi-GPU setups with CUDA 11.2. The WAR creates an extra context on
// GPU 0.
//
// We configure cudaMallocAsync to grow when more memory is needed
// instead of preallocating everything up front, and to keep a local
// pool of up to pool_size bytes that is never released to other
// processes, so no other process will "steal" the GPU memory already
// used by the current process. This speeds up execution and prevents
// crashes of long-running jobs. Use `reserve_memory=true` if you want
// to preallocate the full pool_size. You can also use the environment
// variable `TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=nb_bytes` to
// preallocate that amount of memory.
// `TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=-1` is a special value that
// preallocates everything the BFC memory allocator would have
// allocated. This is useful when benchmarking, as it does not change
// when driver allocations are done.
//
// Here, pool_size is not a hard maximum as it is for [Gpu]BFCAllocator.
// The pool can grow beyond it, up to the total GPU memory, but the
// driver can return the excess memory to other processes.
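//
// A minimal usage sketch, kept as a comment. Assumptions for
// illustration only: in practice the GPU device factory constructs
// this allocator and wires in the stream; `cuda_stream` below stands
// for the CUstream of the device's compute stream, and the alignment
// and sizes are arbitrary example values.
//
//   GpuCudaMallocAsyncAllocator allocator(
//       /*platform_device_id=*/PlatformDeviceId(0),
//       /*pool_size=*/1ull << 30,  // 1 GiB pool; grows on demand.
//       /*reserve_memory=*/false,
//       /*compute_stats=*/true);
//   allocator.SetStreamAndPreallocateMemory(cuda_stream);
//   void* buf = allocator.AllocateRaw(/*alignment=*/256,
//                                     /*num_bytes=*/1 << 20);
//   // ... enqueue work on the compute stream that uses buf ...
//   allocator.DeallocateRaw(buf);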
class GpuCudaMallocAsyncAllocator : public Allocator {
 public:
  explicit GpuCudaMallocAsyncAllocator(PlatformDeviceId platform_device_id,
                                       size_t pool_size,
                                       bool reserve_memory = false,
                                       bool compute_stats = true);
  ~GpuCudaMallocAsyncAllocator() override;
  string Name() override { return name_; }
  void* AllocateRaw(size_t alignment, size_t num_bytes) override;
  void DeallocateRaw(void* ptr) override;

  bool TracksAllocationSizes() const override;

  size_t RequestedSize(const void* ptr) const override;

  size_t AllocatedSize(const void* ptr) const override;

  absl::optional<AllocatorStats> GetStats() override;

  bool ClearStats() override;

  void SetStreamAndPreallocateMemory(void* stream) override;

  // With the right VLOG level set, this prints:
  // - the number of pointers currently allocated, per size (histogram).
  // - each pointer value and its size.
  // - if CUDA_VERSION >= 11030, cudaMallocAsync statistics.
  void PrintAllocatorStatistics();

  static int GetInstantiatedCountTestOnly() { return number_instantiated_; }

  AllocatorMemoryType GetMemoryType() const override {
    return AllocatorMemoryType::kDevice;
  }

 private:
#if TF_CUDA_MALLOC_ASYNC_SUPPORTED
  se::StreamExecutor* stream_exec_;  // Not owned.

  // cudaMallocAsync is stream aware. But the TF StreamExecutor uses
  // only one compute stream and already synchronizes with the h2d,
  // d2h and d2d streams, so we do not need to ask cudaMallocAsync to
  // add extra synchronization.
  // Not owned.
  CUstream cuda_stream_;

  // Not owned. The default pool of the associated GPU.
  // If null, the instantiation failed and the first allocation will
  // return an error.
  CUmemoryPool pool_;
#endif  // TF_CUDA_MALLOC_ASYNC_SUPPORTED

  // Just a counter of the number of times this class has been
  // instantiated. Only useful for tests.
  static std::atomic<int> number_instantiated_;

  string name_;

  bool reserve_memory_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuCudaMallocAsyncAllocator);

  // Stats.
  // Structures mutable after construction.
  mutable mutex lock_;
  std::unique_ptr<AllocatorStats> stats_ TF_PT_GUARDED_BY(lock_);
  absl::flat_hash_map<const void*, size_t> size_map_ TF_GUARDED_BY(lock_);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOCASYNC_ALLOCATOR_H_