// Some stateful GPU libraries, such as cuDNN and cuBLAS, use handles to store state.
// These handles are tied to a device, and these libraries require/recommend that
// handles not be shared across host threads.
//
// These libraries recommend using one handle per host thread. We may not want to do
// this because threads are relatively light-weight, while creating and destroying
// handles is expensive (destroying a handle causes synchronization). DataParallel,
// for example, creates new threads for each forward pass.
//
// This file implements a handle pool mechanism. The handle pool returns handles on
// demand as threads request them. If all existing handles in the pool are in use,
// it creates a new one. As threads terminate, they release handles back into the
// pool. In this way, the handle pool never creates more handles than the high-water
// mark of active threads, so it's efficient with DataParallel.

#pragma once

#include <unordered_map>
#include <vector>
#include <utility>
#include <mutex>
#include <memory>

#include <c10/util/Exception.h>

namespace at::cuda { namespace {

template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {

    struct Handle {
        Handle_t handle;
        Handle(bool create = false) : handle(nullptr)
        {
            if(create) Create(&handle);
        }
        // std::vector::emplace() and push_back() may route through temporaries and call
        // copy/move constructors along the way. If this is the case, we don't want
        // the destructors of temporaries to call cudnnDestroy on the handle.
        // We can achieve safety (for the narrow case of stashing within std::vectors)
        // by making Handle moveable but not copyable, and transferring handle ownership
        // to the latest constructed object. This is not a substitute for full-blown
        // reference counting, but reference counting may be overkill here.
        // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
        // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
        Handle(const Handle& rhs) = delete;
        // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
        Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); }
        // operator= takes its argument by value
        Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
        ~Handle() {
            if(handle) Destroy(handle);
        }
    };

    std::mutex mutex;

    // Handles are lazily created as different threads request them,
    // but are never destroyed until the end of the process.
    // The maximum number of handles this process will create for each device is equal
    // to the high-water mark of the number of concurrently active threads that request
    // handles for that device.
    // When threads terminate, they release their handles back into the pool for reuse.
    // Otherwise, new handles would be created every time new threads were spawned,
    // resulting in poor performance for Python modules that repeatedly or frequently
    // spawned new sets of threads (like DataParallel, which creates a new set of threads
    // for each forward pass).
    //
    // To prevent potential deadlocks, we explicitly choose not to cap the number
    // of handles that are created per device.
    // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
    // only 4 can make forward progress at any time. Those 4 will not release their
    // handles until they exit, so the fifth cannot make progress until then. This is
    // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
    // intermediate point (i.e., before any of them have exited). We have no way to anticipate
    // or enforce that user threads will not attempt such intermediate synchronization.
    // The only way to ensure safety is to avoid imposing a cap on the number of handles.
    std::unordered_map<int, std::vector<Handle>> created_handles;
    std::unordered_map<int, std::vector<Handle_t>> available_handles;

    // PoolWindow lazily creates and caches the handles that a particular thread is using,
    // so in the common case handle access doesn't incur either handle creation or a mutex lock.
    class PoolWindow
    {
    public:
        PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
        ~PoolWindow(){ release(); }

        Handle_t reserve(int device)
        {
            // If this thread already has a handle for this device, return it
            if(my_handles.find(device) != my_handles.end())
                return my_handles[device];

            // Otherwise, either grab a handle from the pool if one is available,
            // or, if not, create a new one.
            auto parent = weak_parent.lock();
            TORCH_CHECK(parent, "Cannot create handle during program termination");
            std::lock_guard<std::mutex> guard(parent->mutex);

            if(parent->available_handles[device].size() > 0)
            {
                my_handles[device] = parent->available_handles[device].back();
                parent->available_handles[device].pop_back();
            }
            else
            {
                // In local testing, emplace_back sometimes routes through temporaries
                // that incur move-constructor and destructor calls. See the comments
                // in Handle above.
                parent->created_handles[device].emplace_back(true /*create*/);
                my_handles[device] = parent->created_handles[device].back().handle;
            }

            return my_handles[device];
        }

    private:
        // Stores the per-device handles currently owned by this thread
        std::unordered_map<int, Handle_t> my_handles;

        std::weak_ptr<DeviceThreadHandlePool> weak_parent;

        // Called by the destructor. Releases this thread's handles back into the pool.
        void release() {
            if(my_handles.size() > 0) {
                auto parent = weak_parent.lock();
                if (!parent) {
                    // If this thread exits after atexit handlers have completed, the
                    // CUDA context itself may be invalid, so we must leak the handles.
                    return;
                }

                std::lock_guard<std::mutex> guard(parent->mutex);
                for(auto d_h : my_handles)
                    parent->available_handles[d_h.first].push_back(d_h.second);
            }
        }
    };

    // Warning:
    // If you want to change this function, be aware that it will be called by multiple
    // threads and there is no mutex guarding the call, so make sure your implementation
    // is thread-safe.
    PoolWindow *newPoolWindow() {
        // The returned pointer will be owned by a thread-local variable
        // so that different threads do not share the same PoolWindow.
        return new PoolWindow(this->shared_from_this());
    }
};

}} // namespace at::cuda::<anonymous>
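
// The sketch below illustrates how a client library might wire this pool up: one
// process-wide pool, one thread_local PoolWindow per thread, and a reserve() call
// per handle request. It is an illustrative sketch only, not part of this header;
// myHandle_t, createMyHandle, destroyMyHandle, and getMyHandle are hypothetical
// placeholders (a real client would use, e.g., cudnnHandle_t with wrappers around
// cudnnCreate/cudnnDestroy), which is why the example is excluded from compilation.
#if 0
#include <memory>

// Hypothetical handle type and lifecycle callbacks.
using myHandle_t = void*;
void createMyHandle(myHandle_t* handle);   // e.g. wraps cudnnCreate
void destroyMyHandle(myHandle_t handle);   // e.g. wraps cudnnDestroy

using MyPoolType =
    at::cuda::DeviceThreadHandlePool<myHandle_t, createMyHandle, destroyMyHandle>;

myHandle_t getMyHandle(int device) {
    // One pool shared by every thread for the lifetime of the process.
    // make_shared is required so that newPoolWindow() can call shared_from_this().
    static auto pool = std::make_shared<MyPoolType>();
    // Each thread lazily constructs its own PoolWindow. When the thread exits,
    // the unique_ptr's destructor runs ~PoolWindow(), which returns the thread's
    // handles to the shared pool for reuse.
    thread_local std::unique_ptr<MyPoolType::PoolWindow> myPoolWindow(
        pool->newPoolWindow());
    return myPoolWindow->reserve(device);
}
#endif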