// Some stateful GPU libraries, such as cuDNN and cuBLAS, use handles to store state.
// These handles are tied to a device, and these libraries require/recommend not
// sharing handles across host threads.
//
// These libraries recommend using one handle per host thread. We may not want to do
// this because threads are relatively light-weight, while creating and destroying
// handles is expensive (destroying a handle causes synchronization). DataParallel,
// for example, creates new threads for each forward pass.
//
// This file implements a handle pool mechanism. The handle pool returns handles on
// demand as threads request them. If all existing handles in the pool are in use,
// it creates a new one. As threads terminate, they release handles back into the pool.
// In this way, the handle pool never creates more handles than the high-water mark of
// active threads, so it's efficient with DataParallel.
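//
// Usage sketch (illustrative only, kept as a comment; the "my*"/"My*" names are
// hypothetical placeholders, not identifiers defined in this repository). The template
// expects void-returning create/destroy callbacks, so a consumer would typically wrap
// its library's status-returning calls and check the status inside the wrappers:
//
//     void createMyHandle(myHandle_t* handle)  { MY_CHECK(myCreateHandle(handle)); }
//     void destroyMyHandle(myHandle_t handle)  { MY_CHECK(myDestroyHandle(handle)); }
//
//     using MyHandlePool =
//         DeviceThreadHandlePool<myHandle_t, createMyHandle, destroyMyHandle>;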

#pragma once

#include <unordered_map>
#include <vector>
#include <utility>
#include <mutex>
#include <memory>

#include <c10/util/Exception.h>

namespace at::cuda { namespace {

template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {

    struct Handle {
    Handle_t handle;
    Handle(bool create = false) : handle(nullptr)
    {
        if(create) Create(&handle);
    }
    // std::vector.emplace() and push_back() may route through temporaries and call
    // copy/move constructors along the way.  If this is the case, we don't want
    // the destructors of temporaries to call Destroy on the handle.
    // We can achieve safety (for the narrow case of stashing within std::vectors)
    // by making Handle movable but not copyable, and transferring handle ownership
    // to the latest constructed object.  This is not a substitute for full-blown
    // reference counting, but reference counting may be overkill here.
    // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
    // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
    Handle(const Handle& rhs) = delete;
    // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
    Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); }
    // operator= takes argument by value
    Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
    ~Handle() {
        if(handle) Destroy(handle);
    }
    };

    std::mutex mutex;

    // Handles are lazily created as different threads request them,
    // but are never destroyed until the end of the process.
    // The maximum number of handles this process will create for each device is equal
    // to the high-water mark of the number of concurrently active threads that request
    // handles for that device.
    // When threads terminate, they release their handles back into the pool for reuse.
    // Otherwise, new handles would be created every time new threads were spawned,
    // resulting in poor performance for Python modules that repeatedly or frequently
    // spawned new sets of threads (like DataParallel, which creates a new set of threads
    // for each forward pass).
    //
    // To prevent potential deadlocks, we explicitly choose not to cap the number
    // of handles that are created per device.
    // Example of the danger: if we cap the max handles at 4 and 5 threads share a device,
    // only 4 can make forward progress at any time. Those 4 will not release their
    // handles until they exit, so the fifth cannot make progress until then.  This is
    // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
    // intermediate point (i.e., before any of them has exited).  We have no way to anticipate
    // or enforce that user threads will not attempt such intermediate synchronization.
    // The only way to ensure safety is to avoid imposing a cap on the number of handles.
    std::unordered_map<int, std::vector<Handle>> created_handles;
    std::unordered_map<int, std::vector<Handle_t>> available_handles;

    // PoolWindow lazily creates and caches the handles that a particular thread is using,
    // so in the common case handle access doesn't incur either handle creation or a mutex lock.
    class PoolWindow
    {
    public:
    PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
    ~PoolWindow(){ release(); }

    Handle_t reserve(int device)
    {
        // If this thread already has a handle for this device, return it
        if(my_handles.find(device) != my_handles.end())
        return my_handles[device];

        // Otherwise, grab a handle from the pool if one is available,
        // or create a new one if not.
        auto parent = weak_parent.lock();
        TORCH_CHECK(parent, "Cannot create handle during program termination");
        std::lock_guard<std::mutex> guard(parent->mutex);

        if(parent->available_handles[device].size() > 0)
        {
        my_handles[device] = parent->available_handles[device].back();
        parent->available_handles[device].pop_back();
        }
        else
        {
        // In local testing, I do observe that emplace_back sometimes routes through temporaries
        // that incur move-constructor and destructor calls.  See comments in Handle above.
        parent->created_handles[device].emplace_back(true /*create*/);
        my_handles[device] = parent->created_handles[device].back().handle;
        }

        return my_handles[device];
    }

    private:
    // Stores the per-device handles currently owned by this thread
    std::unordered_map<int, Handle_t> my_handles;

    std::weak_ptr<DeviceThreadHandlePool> weak_parent;

    // Called by the destructor.  Releases this thread's handles back into the pool.
    void release() {
        if(my_handles.size() > 0) {
            auto parent = weak_parent.lock();
            if (!parent) {
                // If this thread exits after atexit handlers have completed, the
                // CUDA context itself may be invalid, so we must leak the handles.
                return;
            }

            std::lock_guard<std::mutex> guard(parent->mutex);
            for(auto d_h : my_handles)
                parent->available_handles[d_h.first].push_back(d_h.second);
        }
    }
    };

    // Warning:
    // If you want to change this function, be aware that it will be called by multiple
    // threads and there is no mutex guarding the call, so make sure your
    // implementation is thread-safe.
    PoolWindow *newPoolWindow() {
        // The returned pointer will be owned by a thread local variable
        // so that different threads do not share the same PoolWindow.
        return new PoolWindow(this->shared_from_this());
    }
};
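
// A sketch of how a consumer might wire the pool into a per-thread accessor
// (illustrative only; MyHandlePool, myHandle_t, and getMyHandle continue the
// hypothetical names from the usage sketch at the top of this file):
//
//     myHandle_t getMyHandle(int device) {
//         // One pool shared by all threads. newPoolWindow() requires the pool to be
//         // owned by a shared_ptr, since it calls shared_from_this().
//         static auto pool = std::make_shared<MyHandlePool>();
//
//         // Each thread owns its own PoolWindow; its destructor runs at thread exit
//         // and returns the thread's handles to the pool. The window holds only a
//         // weak_ptr to the pool, so a thread exiting after static destruction
//         // simply leaks its handles, as documented in release() above.
//         thread_local std::unique_ptr<MyHandlePool::PoolWindow> myPoolWindow(
//             pool->newPoolWindow());
//
//         return myPoolWindow->reserve(device);
//     }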

}}  // namespace at::cuda::<anonymous>