1 #pragma once 2 3 #include <cuda_runtime_api.h> 4 5 #include <c10/core/DeviceGuard.h> 6 #include <c10/core/Stream.h> 7 #include <c10/cuda/CUDAFunctions.h> 8 #include <c10/util/Exception.h> 9 10 /* 11 * Stream pool note. 12 * 13 * A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams 14 * are backed by cuStreams, but they use several pools to minimize the costs 15 * associated with creating, retaining, and destroying cuStreams. 16 * 17 * There are three pools per device, and a device's pools are lazily created. 18 * 19 * The first pool contains only the default stream. When the default stream 20 * is requested it's returned. 21 * 22 * The second pool is the "low priority" or "default priority" streams. In 23 * HIP builds there is no distinction between streams in this pool and streams 24 * in the third pool (below). There are 32 of these streams per device, and 25 * when a stream is requested one of these streams is returned round-robin. 26 * That is, the first stream requested is at index 0, the second at index 1... 27 * to index 31, then index 0 again. 28 * 29 * This means that if 33 low priority streams are requested, the first and 30 * last streams requested are actually the same stream (under the covers) 31 * and kernels enqueued on them cannot run concurrently. 32 * 33 * The third pool is the "high priority" streams. The third pool acts like 34 * the second pool except the streams are created with a higher priority. 35 * 36 * These pools suggest that stream users should prefer many short-lived streams, 37 * as the cost of acquiring and releasing streams is effectively zero. If 38 * many longer-lived streams are required in performance critical scenarios 39 * then the functionality here may need to be extended to allow, for example, 40 * "reserving" a subset of the pool so that other streams do not accidentally 41 * overlap the performance critical streams. 42 * 43 * Note: although the notion of "current stream for device" is thread local 44 * (every OS thread has a separate current stream, as one might expect), 45 * the stream pool is global across all threads; stream 0 is always stream 0 46 * no matter which thread you use it on. Multiple threads can synchronize 47 * on the same stream. Although the CUDA documentation is not very clear 48 * on the matter, streams are thread safe; e.g., it is safe to enqueue 49 * a kernel on the same stream from two different threads. 50 */ 51 52 namespace c10::cuda { 53 54 static constexpr int max_compile_time_stream_priorities = 4; 55 56 // Value object representing a CUDA stream. This is just a wrapper 57 // around c10::Stream, but it comes with a little extra CUDA-specific 58 // functionality (conversion to cudaStream_t), and a guarantee that 59 // the wrapped c10::Stream really is a CUDA stream. 60 class C10_CUDA_API CUDAStream { 61 public: 62 enum Unchecked { UNCHECKED }; 63 64 /// Construct a CUDAStream from a Stream. This construction is checked, 65 /// and will raise an error if the Stream is not, in fact, a CUDA stream. CUDAStream(Stream stream)66 explicit CUDAStream(Stream stream) : stream_(stream) { 67 TORCH_CHECK(stream_.device_type() == DeviceType::CUDA); 68 } 69 70 /// Construct a CUDAStream from a Stream with no error checking. 71 /// This constructor uses the "named" constructor idiom, and can 72 /// be invoked as: CUDAStream(CUDAStream::UNCHECKED, stream) CUDAStream(Unchecked,Stream stream)73 explicit CUDAStream(Unchecked, Stream stream) : stream_(stream) {} 74 75 bool operator==(const CUDAStream& other) const noexcept { 76 return unwrap() == other.unwrap(); 77 } 78 79 bool operator!=(const CUDAStream& other) const noexcept { 80 return unwrap() != other.unwrap(); 81 } 82 83 /// Implicit conversion to cudaStream_t. cudaStream_t()84 operator cudaStream_t() const { 85 return stream(); 86 } 87 88 /// Implicit conversion to Stream (a.k.a., forget that the stream is a 89 /// CUDA stream). Stream()90 operator Stream() const { 91 return unwrap(); 92 } 93 94 /// Used to avoid baking in device type explicitly to Python-side API. device_type()95 DeviceType device_type() const { 96 return DeviceType::CUDA; 97 } 98 99 /// Get the CUDA device index that this stream is associated with. device_index()100 DeviceIndex device_index() const { 101 return stream_.device_index(); 102 } 103 104 /// Get the full Device that this stream is associated with. The Device 105 /// is guaranteed to be a CUDA device. device()106 Device device() const { 107 return Device(DeviceType::CUDA, device_index()); 108 } 109 110 /// Return the stream ID corresponding to this particular stream. id()111 StreamId id() const { 112 return stream_.id(); 113 } 114 query()115 bool query() const { 116 DeviceGuard guard{stream_.device()}; 117 cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream())); 118 119 if (err == cudaSuccess) { 120 return true; 121 } else if (err != cudaErrorNotReady) { 122 C10_CUDA_CHECK(err); 123 } else { 124 // ignore and clear the error if not ready 125 (void)cudaGetLastError(); 126 } 127 128 return false; 129 } 130 synchronize()131 void synchronize() const { 132 DeviceGuard guard{stream_.device()}; 133 c10::cuda::stream_synchronize(stream()); 134 } 135 priority()136 int priority() const { 137 DeviceGuard guard{stream_.device()}; 138 int priority = 0; 139 C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority)); 140 return priority; 141 } 142 143 /// Explicit conversion to cudaStream_t. 144 cudaStream_t stream() const; 145 146 /// Explicit conversion to Stream. unwrap()147 Stream unwrap() const { 148 return stream_; 149 } 150 151 /// Reversibly pack a CUDAStream into a struct representation. 152 /// Previously the stream's data was packed into a single int64_t, 153 /// as it was assumed the fields would not require more than 154 /// 64 bits of storage in total. 155 /// See https://github.com/pytorch/pytorch/issues/75854 156 /// for more information regarding newer platforms that may violate 157 /// this assumption. 158 /// 159 /// The CUDAStream can be unpacked using unpack(). pack3()160 struct c10::StreamData3 pack3() const { 161 return stream_.pack3(); 162 } 163 164 // Unpack a CUDAStream from the 3 fields generated by pack(). unpack3(StreamId stream_id,DeviceIndex device_index,DeviceType device_type)165 static CUDAStream unpack3( 166 StreamId stream_id, 167 DeviceIndex device_index, 168 DeviceType device_type) { 169 return CUDAStream(Stream::unpack3(stream_id, device_index, device_type)); 170 } 171 priority_range()172 static std::tuple<int, int> priority_range() { 173 // Note: this returns the range of priority **supported by PyTorch**, not 174 // the range of priority **supported by CUDA**. The former is a subset of 175 // the latter. 176 int least_priority = 0, greatest_priority = 0; 177 C10_CUDA_CHECK( 178 cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority)); 179 #ifdef USE_ROCM 180 // See Note [HIP stream priorities] 181 TORCH_INTERNAL_ASSERT( 182 least_priority == 1, "Unexpected HIP stream priority range"); 183 least_priority = 0; 184 #else 185 TORCH_INTERNAL_ASSERT( 186 least_priority == 0, "Unexpected CUDA stream priority range"); 187 #endif 188 TORCH_INTERNAL_ASSERT( 189 greatest_priority <= -1, "Unexpected CUDA stream priority range"); 190 greatest_priority = std::max( 191 -c10::cuda::max_compile_time_stream_priorities + 1, greatest_priority); 192 return std::make_tuple(least_priority, greatest_priority); 193 } 194 195 // Deleted for now; use CUDAEvent::block instead 196 // void synchronize_with(const CUDAEvent& event) const; 197 198 private: 199 Stream stream_; 200 }; 201 202 /** 203 * Get a new stream from the CUDA stream pool. You can think of this 204 * as "creating" a new stream, but no such creation actually happens; 205 * instead, streams are preallocated from the pool and returned in a 206 * round-robin fashion. 207 * 208 * You can request a stream from the high priority pool by setting 209 * isHighPriority to true, or a stream for a specific device by setting device 210 * (defaulting to the current CUDA stream.) 211 */ 212 C10_API CUDAStream 213 getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); 214 // no default priority to disambiguate overloads 215 C10_API CUDAStream 216 getStreamFromPool(const int priority, DeviceIndex device = -1); 217 218 /** 219 * Get a CUDAStream from a externally allocated one. 220 * 221 * This is mainly for interoperability with different libraries where we 222 * want to operate on a non-torch allocated stream for data exchange or similar 223 * purposes 224 */ 225 C10_API CUDAStream 226 getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index); 227 228 /** 229 * Get the default CUDA stream, for the passed CUDA device, or for the 230 * current device if no device index is passed. The default stream is 231 * where most computation occurs when you aren't explicitly using 232 * streams. 233 */ 234 C10_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1); 235 236 /** 237 * Get the current CUDA stream, for the passed CUDA device, or for the 238 * current device if no device index is passed. The current CUDA stream 239 * will usually be the default CUDA stream for the device, but it may 240 * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard' 241 * or 'CUDAStreamGuard'. 242 */ 243 C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1); 244 245 /** 246 * Set the current stream on the device of the passed in stream to be 247 * the passed in stream. Yes, you read that right: this function 248 * has *nothing* to do with the current device: it toggles the current 249 * stream of the device of the passed stream. 250 * 251 * Confused? Avoid using this function; prefer using 'CUDAStreamGuard' instead 252 * (which will switch both your current device and current stream in the way you 253 * expect, and reset it back to its original state afterwards). 254 */ 255 C10_API void setCurrentCUDAStream(CUDAStream stream); 256 257 C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s); 258 259 } // namespace c10::cuda 260 261 namespace std { 262 template <> 263 struct hash<c10::cuda::CUDAStream> { 264 size_t operator()(c10::cuda::CUDAStream s) const noexcept { 265 return std::hash<c10::Stream>{}(s.unwrap()); 266 } 267 }; 268 } // namespace std 269