xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDAStream.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/GPUTrace.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAFunctions.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAGuard.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAStream.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/CallOnce.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
7*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker #include <array>
10*da0073e9SAndroid Build Coastguard Worker #include <atomic>
11*da0073e9SAndroid Build Coastguard Worker #include <cstdint>
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker namespace {
16*da0073e9SAndroid Build Coastguard Worker 
// Global stream state and constants

// Guards the one-time call to initGlobalStreamState().
static c10::once_flag init_flag;
// Number of visible devices; stays -1 until initGlobalStreamState() runs.
static DeviceIndex num_gpus = -1;
// Width of the pool-index field of a native StreamId (see
// Note [StreamId assignment]); each priority pool therefore holds
// 1 << kStreamsPerPoolBits streams.
static constexpr int kStreamsPerPoolBits = 5;
static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
// Flags passed to cudaStreamCreateWithPriority for every pooled stream.
static constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking;
// Width of the StreamIdType field of a StreamId.
static constexpr int kStreamTypeBits = 4;

// Number of priority pools actually in use: the device's stream priority
// range, capped at max_compile_time_stream_priorities. Set by
// initGlobalStreamState().
static int max_stream_priorities;

// Non-default streams
// Note: the number of CUDA devices is determined at run time, and each
// device's per-priority pools are lazily initialized when the first stream is
// requested for that device (on ROCm each stream is instead created
// individually on first use; see Note [HIP Lazy Streams]).
// The device flags track the initialization of each device, while the
// per-priority counters track, for each device, the next stream in the pool
// to be returned when a stream is requested (round-robin fashion; see the
// note in CUDAStream.h).
// The streams are "leaked": they are created but never destroyed because the
// destruction of global variables could happen after the CUDA runtime has
// already been destroyed and thus invoking cudaStreamDestroy could lead to a
// crash. It's likely an issue in CUDA, but to be safe - let's just "forget"
// the destruction.
#if !defined(USE_ROCM)
// CUDA-only: used to initialize the stream pools (once per device).
static std::array<c10::once_flag, C10_COMPILE_TIME_MAX_GPUS> device_flags;
#endif
// Round-robin counters, indexed [priority pool][device].
static std::array<
    std::array<std::atomic<uint32_t>, C10_COMPILE_TIME_MAX_GPUS>,
    c10::cuda::max_compile_time_stream_priorities>
    priority_counters;

// Pooled stream handles, indexed [priority pool][device][pool index].
// Pool p holds streams created with CUDA priority -p.
static std::array<
    std::array<
        std::array<cudaStream_t, kStreamsPerPool>,
        C10_COMPILE_TIME_MAX_GPUS>,
    c10::cuda::max_compile_time_stream_priorities>
    streams;
#ifdef USE_ROCM
// ROCm-only: one flag per pooled stream, indexed like `streams`, so that each
// stream is created lazily and exactly once (see Note [HIP Lazy Streams]).
static c10::once_flag
    stream_flags[c10::cuda::max_compile_time_stream_priorities]
                [C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool];
#endif
60*da0073e9SAndroid Build Coastguard Worker 
// Note [HIP Lazy Streams]
// ~~~~~~~~~~~~~~~~~~~~~~~
// For ROCm/HIP, each pooled stream is initialized lazily, instead of all of a
// device's streams being created when its first stream is requested. HIP
// streams are not as lightweight as CUDA streams, so the pooling strategy can
// affect performance. Rather than changing the pooling implementation,
// ROCm/HIP lazily initializes each stream the first time it is requested.
68*da0073e9SAndroid Build Coastguard Worker 
// Note [StreamId assignment]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// How do we assign stream IDs?
//
// -- 54 bits --  -- 5 bits -----  -- 4 bits --     --1 bit --
// zeros          stream id index  StreamIdType     Ext/native stream
//                ignored for ext   ignored for ext
//
// For an external stream, the StreamId is the cudaStream_t pointer itself.
// Such pointers are expected to be at least 2-byte aligned, so their last bit
// is 0; when constructing the StreamId of a native stream we therefore set the
// last bit to 1 to distinguish native streams from external ones.
//
// We are obligated to treat the stream ID 0 as the default stream, per the
// invariant specified in c10::Stream, so this is the one exception to
// "last bit = 1 for native streams". All other numbers are entirely an
// internal implementation detail, and we reserve the right to renumber
// streams however we like.
//
// Note that it is really important that the MSB is zero: StreamId is a
// *signed* integer, and converting an unsigned value that is outside the
// range of the signed type is implementation-defined (before C++20), not
// portable. You could work around this with something like
// https://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
// but it seems a bit overkill for this.
//
// Also, because externally managed stream pointers (cudaStream_t) are stored
// directly in the id, this scheme depends on their alignment: a pointer whose
// lowest bit is 1 cannot be represented as an external StreamId.
97*da0073e9SAndroid Build Coastguard Worker 
// Decoded "type" field of a StreamId (see Note [StreamId assignment]).
// DEFAULT (0) marks the default stream, EXT (0xF) marks an externally
// allocated stream, and any other value is the priority level of a pooled
// native stream (a higher value is a higher priority).
class StreamIdType {
 public:
  static constexpr uint8_t DEFAULT = 0x0;
  static constexpr uint8_t EXT = 0xF;

  // Deliberately implicit so raw type bytes convert at call sites, e.g.
  // makeStreamId(StreamIdType::DEFAULT, 0).
  StreamIdType(const uint8_t _stream_type) : stream_type(_stream_type) {}

  bool isExt() const {
    return stream_type == EXT;
  }

  bool isDefault() const {
    return stream_type == DEFAULT;
  }

  uint8_t getStreamType() const {
    return stream_type;
  }

 private:
  uint8_t stream_type;
};
124*da0073e9SAndroid Build Coastguard Worker 
operator <<(std::ostream & stream,StreamIdType s)125*da0073e9SAndroid Build Coastguard Worker std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
126*da0073e9SAndroid Build Coastguard Worker   if (s.isDefault()) {
127*da0073e9SAndroid Build Coastguard Worker     stream << "DEFAULT";
128*da0073e9SAndroid Build Coastguard Worker   } else if (s.isExt()) {
129*da0073e9SAndroid Build Coastguard Worker     stream << "EXT";
130*da0073e9SAndroid Build Coastguard Worker   } else {
131*da0073e9SAndroid Build Coastguard Worker     stream << "PRIORITY " << int(s.getStreamType());
132*da0073e9SAndroid Build Coastguard Worker   }
133*da0073e9SAndroid Build Coastguard Worker   return stream;
134*da0073e9SAndroid Build Coastguard Worker }
135*da0073e9SAndroid Build Coastguard Worker 
136*da0073e9SAndroid Build Coastguard Worker // StreamId is 64-bit, so we can just rely on regular promotion rules.
137*da0073e9SAndroid Build Coastguard Worker // We rely on streamIdIndex and streamIdType being non-negative;
138*da0073e9SAndroid Build Coastguard Worker // see Note [Hazard when concatenating signed integers]
139*da0073e9SAndroid Build Coastguard Worker 
streamIdType(StreamId s)140*da0073e9SAndroid Build Coastguard Worker static inline StreamIdType streamIdType(StreamId s) {
141*da0073e9SAndroid Build Coastguard Worker   // Externally allocated streams have their id being the cudaStream_ptr
142*da0073e9SAndroid Build Coastguard Worker   // so the last bit will be 0
143*da0073e9SAndroid Build Coastguard Worker   if ((!(s & 1)) && s) {
144*da0073e9SAndroid Build Coastguard Worker     return StreamIdType(StreamIdType::EXT);
145*da0073e9SAndroid Build Coastguard Worker   }
146*da0073e9SAndroid Build Coastguard Worker   // last bit is external/internal stream, the mask should start from second
147*da0073e9SAndroid Build Coastguard Worker   // rightmost bit
148*da0073e9SAndroid Build Coastguard Worker   int mask_for_type = (1 << kStreamTypeBits) - 1;
149*da0073e9SAndroid Build Coastguard Worker   auto val = (s >> 1) & mask_for_type;
150*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(val || !(s & 1), "invalid StreamId", s);
151*da0073e9SAndroid Build Coastguard Worker   return StreamIdType(val);
152*da0073e9SAndroid Build Coastguard Worker }
153*da0073e9SAndroid Build Coastguard Worker 
streamIdIndex(StreamId s)154*da0073e9SAndroid Build Coastguard Worker static inline size_t streamIdIndex(StreamId s) {
155*da0073e9SAndroid Build Coastguard Worker   return static_cast<size_t>(
156*da0073e9SAndroid Build Coastguard Worker       (s >> (kStreamTypeBits + 1)) & ((1 << kStreamsPerPoolBits) - 1));
157*da0073e9SAndroid Build Coastguard Worker }
158*da0073e9SAndroid Build Coastguard Worker 
makeStreamId(StreamIdType st,size_t si)159*da0073e9SAndroid Build Coastguard Worker StreamId makeStreamId(StreamIdType st, size_t si) {
160*da0073e9SAndroid Build Coastguard Worker   if (st.isDefault()) {
161*da0073e9SAndroid Build Coastguard Worker     return static_cast<StreamId>(0);
162*da0073e9SAndroid Build Coastguard Worker   }
163*da0073e9SAndroid Build Coastguard Worker   return (static_cast<StreamId>(si) << (kStreamTypeBits + 1)) |
164*da0073e9SAndroid Build Coastguard Worker       static_cast<StreamId>(st.getStreamType() << 1) | 1;
165*da0073e9SAndroid Build Coastguard Worker }
166*da0073e9SAndroid Build Coastguard Worker 
// Thread-local current streams: one StreamId per device, indexed by device
// index. Allocated on first use in each thread by initCUDAStreamsOnce(),
// which fills every entry with the default stream's id.
// NOLINTNEXTLINE(*-arrays)
static thread_local std::unique_ptr<StreamId[]> current_streams = nullptr;
170*da0073e9SAndroid Build Coastguard Worker 
171*da0073e9SAndroid Build Coastguard Worker // Populates global values.
172*da0073e9SAndroid Build Coastguard Worker // Warning: this function must only be called once!
initGlobalStreamState()173*da0073e9SAndroid Build Coastguard Worker static void initGlobalStreamState() {
174*da0073e9SAndroid Build Coastguard Worker   num_gpus = device_count();
175*da0073e9SAndroid Build Coastguard Worker   // Check if the number of GPUs matches the expected compile-time max number
176*da0073e9SAndroid Build Coastguard Worker   // of GPUs.
177*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
178*da0073e9SAndroid Build Coastguard Worker       num_gpus <= C10_COMPILE_TIME_MAX_GPUS,
179*da0073e9SAndroid Build Coastguard Worker       "Number of CUDA devices on the machine is larger than the compiled "
180*da0073e9SAndroid Build Coastguard Worker       "max number of gpus expected (",
181*da0073e9SAndroid Build Coastguard Worker       C10_COMPILE_TIME_MAX_GPUS,
182*da0073e9SAndroid Build Coastguard Worker       "). Increase that and recompile.");
183*da0073e9SAndroid Build Coastguard Worker   int leastPriority = -1, greatestPriority = -1;
184*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(
185*da0073e9SAndroid Build Coastguard Worker       cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority));
186*da0073e9SAndroid Build Coastguard Worker   // Note [HIP stream priorities]
187*da0073e9SAndroid Build Coastguard Worker   // HIP stream priorities are 1=low, 0=default, -1=high which differs from CUDA
188*da0073e9SAndroid Build Coastguard Worker   // which is 0=default, -1=high, -2=higher etc.
189*da0073e9SAndroid Build Coastguard Worker   // Clamp leastPriority to 0 for HIP.
190*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ROCM
191*da0073e9SAndroid Build Coastguard Worker   leastPriority = 0;
192*da0073e9SAndroid Build Coastguard Worker #endif
193*da0073e9SAndroid Build Coastguard Worker   // greatestPriority is negative
194*da0073e9SAndroid Build Coastguard Worker   auto range = leastPriority - greatestPriority + 1;
195*da0073e9SAndroid Build Coastguard Worker   max_stream_priorities = range >= c10::cuda::max_compile_time_stream_priorities
196*da0073e9SAndroid Build Coastguard Worker       ? c10::cuda::max_compile_time_stream_priorities
197*da0073e9SAndroid Build Coastguard Worker       : range;
198*da0073e9SAndroid Build Coastguard Worker }
199*da0073e9SAndroid Build Coastguard Worker 
200*da0073e9SAndroid Build Coastguard Worker // Init a single CUDA or HIP stream
201*da0073e9SAndroid Build Coastguard Worker // See Note [HIP Lazy Streams]
initSingleStream(int p,DeviceIndex device_index,int i)202*da0073e9SAndroid Build Coastguard Worker static void initSingleStream(int p, DeviceIndex device_index, int i) {
203*da0073e9SAndroid Build Coastguard Worker   auto& stream = streams[p][device_index][i];
204*da0073e9SAndroid Build Coastguard Worker   auto pri = -p; // lower number is higher priority
205*da0073e9SAndroid Build Coastguard Worker 
206*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(cudaStreamCreateWithPriority(&stream, kDefaultFlags, pri));
207*da0073e9SAndroid Build Coastguard Worker   const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
208*da0073e9SAndroid Build Coastguard Worker   if (C10_UNLIKELY(interp)) {
209*da0073e9SAndroid Build Coastguard Worker     (*interp)->trace_gpu_stream_creation(
210*da0073e9SAndroid Build Coastguard Worker         c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
211*da0073e9SAndroid Build Coastguard Worker     priority_counters[p][device_index] = 0;
212*da0073e9SAndroid Build Coastguard Worker   }
213*da0073e9SAndroid Build Coastguard Worker }
214*da0073e9SAndroid Build Coastguard Worker 
215*da0073e9SAndroid Build Coastguard Worker // Creates the low and high priority stream pools for the specified device
216*da0073e9SAndroid Build Coastguard Worker // Warning: only call once per device!
initDeviceStreamState(DeviceIndex device_index)217*da0073e9SAndroid Build Coastguard Worker static void initDeviceStreamState(DeviceIndex device_index) {
218*da0073e9SAndroid Build Coastguard Worker   // Switches to the requested device so streams are properly associated
219*da0073e9SAndroid Build Coastguard Worker   // with it.
220*da0073e9SAndroid Build Coastguard Worker   CUDAGuard device_guard{device_index};
221*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(kStreamsPerPool)) {
222*da0073e9SAndroid Build Coastguard Worker     for (const auto p : c10::irange(max_stream_priorities)) {
223*da0073e9SAndroid Build Coastguard Worker       initSingleStream(p, device_index, i);
224*da0073e9SAndroid Build Coastguard Worker     }
225*da0073e9SAndroid Build Coastguard Worker   }
226*da0073e9SAndroid Build Coastguard Worker }
227*da0073e9SAndroid Build Coastguard Worker 
228*da0073e9SAndroid Build Coastguard Worker // Init front-end to ensure initialization only occurs once
initCUDAStreamsOnce()229*da0073e9SAndroid Build Coastguard Worker static void initCUDAStreamsOnce() {
230*da0073e9SAndroid Build Coastguard Worker   // Inits default streams (once, globally)
231*da0073e9SAndroid Build Coastguard Worker   c10::call_once(init_flag, initGlobalStreamState);
232*da0073e9SAndroid Build Coastguard Worker 
233*da0073e9SAndroid Build Coastguard Worker   if (current_streams) {
234*da0073e9SAndroid Build Coastguard Worker     return;
235*da0073e9SAndroid Build Coastguard Worker   }
236*da0073e9SAndroid Build Coastguard Worker 
237*da0073e9SAndroid Build Coastguard Worker   // Inits current streams (thread local) to default streams
238*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(*-arrays)
239*da0073e9SAndroid Build Coastguard Worker   current_streams = std::make_unique<StreamId[]>(num_gpus);
240*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(num_gpus)) {
241*da0073e9SAndroid Build Coastguard Worker     current_streams[i] = makeStreamId(StreamIdType::DEFAULT, 0);
242*da0073e9SAndroid Build Coastguard Worker   }
243*da0073e9SAndroid Build Coastguard Worker }
244*da0073e9SAndroid Build Coastguard Worker 
// Helper to verify the GPU index is valid: asserts that device_index lies in
// [0, num_gpus), where num_gpus is the device count recorded by
// initGlobalStreamState(). Every caller in this file runs
// initCUDAStreamsOnce() first, so num_gpus is populated by then.
static inline void check_gpu(DeviceIndex device_index) {
  TORCH_INTERNAL_ASSERT(device_index >= 0 && device_index < num_gpus);
}
249*da0073e9SAndroid Build Coastguard Worker 
250*da0073e9SAndroid Build Coastguard Worker // Helper to determine the index of the stream to return
251*da0073e9SAndroid Build Coastguard Worker // Note: Streams are returned round-robin (see note in CUDAStream.h)
get_idx(std::atomic<uint32_t> & counter)252*da0073e9SAndroid Build Coastguard Worker static uint32_t get_idx(std::atomic<uint32_t>& counter) {
253*da0073e9SAndroid Build Coastguard Worker   auto raw_idx = counter++;
254*da0073e9SAndroid Build Coastguard Worker   return raw_idx % kStreamsPerPool;
255*da0073e9SAndroid Build Coastguard Worker }
256*da0073e9SAndroid Build Coastguard Worker 
CUDAStreamForId(DeviceIndex device_index,StreamId stream_id)257*da0073e9SAndroid Build Coastguard Worker CUDAStream CUDAStreamForId(DeviceIndex device_index, StreamId stream_id) {
258*da0073e9SAndroid Build Coastguard Worker   return CUDAStream(
259*da0073e9SAndroid Build Coastguard Worker       CUDAStream::UNCHECKED,
260*da0073e9SAndroid Build Coastguard Worker       Stream(
261*da0073e9SAndroid Build Coastguard Worker           Stream::UNSAFE,
262*da0073e9SAndroid Build Coastguard Worker           c10::Device(DeviceType::CUDA, device_index),
263*da0073e9SAndroid Build Coastguard Worker           stream_id));
264*da0073e9SAndroid Build Coastguard Worker }
265*da0073e9SAndroid Build Coastguard Worker 
266*da0073e9SAndroid Build Coastguard Worker } // anonymous namespace
267*da0073e9SAndroid Build Coastguard Worker 
268*da0073e9SAndroid Build Coastguard Worker // See Note [StreamId assignment]
stream() const269*da0073e9SAndroid Build Coastguard Worker cudaStream_t CUDAStream::stream() const {
270*da0073e9SAndroid Build Coastguard Worker   c10::DeviceIndex device_index = stream_.device_index();
271*da0073e9SAndroid Build Coastguard Worker   StreamId stream_id = stream_.id();
272*da0073e9SAndroid Build Coastguard Worker   StreamIdType st = streamIdType(stream_id);
273*da0073e9SAndroid Build Coastguard Worker   size_t si = streamIdIndex(stream_id);
274*da0073e9SAndroid Build Coastguard Worker   if (st.isDefault()) {
275*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(
276*da0073e9SAndroid Build Coastguard Worker         si == 0,
277*da0073e9SAndroid Build Coastguard Worker         "Unrecognized stream ",
278*da0073e9SAndroid Build Coastguard Worker         stream_,
279*da0073e9SAndroid Build Coastguard Worker         " (I think this should be the default stream, but I got a non-zero index ",
280*da0073e9SAndroid Build Coastguard Worker         si,
281*da0073e9SAndroid Build Coastguard Worker         ").",
282*da0073e9SAndroid Build Coastguard Worker         " Did you manufacture the StreamId yourself?  Don't do that; use the",
283*da0073e9SAndroid Build Coastguard Worker         " official API like c10::cuda::getStreamFromPool() to get a new stream.");
284*da0073e9SAndroid Build Coastguard Worker     return nullptr;
285*da0073e9SAndroid Build Coastguard Worker   } else if (st.isExt()) {
286*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(performance-no-int-to-ptr)
287*da0073e9SAndroid Build Coastguard Worker     return reinterpret_cast<cudaStream_t>(stream_id);
288*da0073e9SAndroid Build Coastguard Worker   } else {
289*da0073e9SAndroid Build Coastguard Worker     auto streamType = st.getStreamType();
290*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(
291*da0073e9SAndroid Build Coastguard Worker         streamType >= 1 && streamType <= max_stream_priorities,
292*da0073e9SAndroid Build Coastguard Worker         "Unrecognized stream ",
293*da0073e9SAndroid Build Coastguard Worker         stream_,
294*da0073e9SAndroid Build Coastguard Worker         " (I didn't recognize the stream type, ",
295*da0073e9SAndroid Build Coastguard Worker         st,
296*da0073e9SAndroid Build Coastguard Worker         " with the value ",
297*da0073e9SAndroid Build Coastguard Worker         streamType,
298*da0073e9SAndroid Build Coastguard Worker         ")");
299*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ROCM
300*da0073e9SAndroid Build Coastguard Worker     // See Note [HIP Lazy Streams]
301*da0073e9SAndroid Build Coastguard Worker     c10::call_once(
302*da0073e9SAndroid Build Coastguard Worker         stream_flags[st.getStreamType() - 1][device_index][si],
303*da0073e9SAndroid Build Coastguard Worker         initSingleStream,
304*da0073e9SAndroid Build Coastguard Worker         st.getStreamType() - 1,
305*da0073e9SAndroid Build Coastguard Worker         device_index,
306*da0073e9SAndroid Build Coastguard Worker         si);
307*da0073e9SAndroid Build Coastguard Worker #endif
308*da0073e9SAndroid Build Coastguard Worker     return streams[st.getStreamType() - 1][device_index][si];
309*da0073e9SAndroid Build Coastguard Worker   }
310*da0073e9SAndroid Build Coastguard Worker }
311*da0073e9SAndroid Build Coastguard Worker 
312*da0073e9SAndroid Build Coastguard Worker // Returns a stream from the requested pool
313*da0073e9SAndroid Build Coastguard Worker // Note: when called the first time on a device, this will create the
314*da0073e9SAndroid Build Coastguard Worker // stream pools for that device.
getStreamFromPool(const int priority,DeviceIndex device_index)315*da0073e9SAndroid Build Coastguard Worker CUDAStream getStreamFromPool(const int priority, DeviceIndex device_index) {
316*da0073e9SAndroid Build Coastguard Worker   initCUDAStreamsOnce();
317*da0073e9SAndroid Build Coastguard Worker   if (device_index == -1) {
318*da0073e9SAndroid Build Coastguard Worker     device_index = current_device();
319*da0073e9SAndroid Build Coastguard Worker     c10::cuda::SetTargetDevice();
320*da0073e9SAndroid Build Coastguard Worker   }
321*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
322*da0073e9SAndroid Build Coastguard Worker       priority <= 0,
323*da0073e9SAndroid Build Coastguard Worker       "Expected cuda stream priority to be less than or equal to 0, got ",
324*da0073e9SAndroid Build Coastguard Worker       priority);
325*da0073e9SAndroid Build Coastguard Worker   check_gpu(device_index);
326*da0073e9SAndroid Build Coastguard Worker #if !defined(USE_ROCM)
327*da0073e9SAndroid Build Coastguard Worker   // See Note [HIP Lazy Streams]
328*da0073e9SAndroid Build Coastguard Worker   // CUDA-only: Initializes the stream pools (once)
329*da0073e9SAndroid Build Coastguard Worker   c10::call_once(
330*da0073e9SAndroid Build Coastguard Worker       device_flags[device_index], initDeviceStreamState, device_index);
331*da0073e9SAndroid Build Coastguard Worker #endif
332*da0073e9SAndroid Build Coastguard Worker   auto pri_idx = -priority;
333*da0073e9SAndroid Build Coastguard Worker   pri_idx =
334*da0073e9SAndroid Build Coastguard Worker       std::min(pri_idx, max_stream_priorities - 1); // pri_idx is zero-based
335*da0073e9SAndroid Build Coastguard Worker   const auto idx = get_idx(priority_counters[pri_idx][device_index]);
336*da0073e9SAndroid Build Coastguard Worker   StreamIdType id_type = StreamIdType(pri_idx + 1);
337*da0073e9SAndroid Build Coastguard Worker   return CUDAStreamForId(device_index, makeStreamId(id_type, idx));
338*da0073e9SAndroid Build Coastguard Worker }
339*da0073e9SAndroid Build Coastguard Worker 
getStreamFromPool(const bool isHighPriority,DeviceIndex device)340*da0073e9SAndroid Build Coastguard Worker CUDAStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) {
341*da0073e9SAndroid Build Coastguard Worker   initCUDAStreamsOnce();
342*da0073e9SAndroid Build Coastguard Worker   int priority = isHighPriority ? -max_stream_priorities + 1 : 0;
343*da0073e9SAndroid Build Coastguard Worker   return getStreamFromPool(priority, device);
344*da0073e9SAndroid Build Coastguard Worker }
345*da0073e9SAndroid Build Coastguard Worker 
getStreamFromExternal(cudaStream_t ext_stream,DeviceIndex device_index)346*da0073e9SAndroid Build Coastguard Worker CUDAStream getStreamFromExternal(
347*da0073e9SAndroid Build Coastguard Worker     cudaStream_t ext_stream,
348*da0073e9SAndroid Build Coastguard Worker     DeviceIndex device_index) {
349*da0073e9SAndroid Build Coastguard Worker   // The stream pointer will be the actual id
350*da0073e9SAndroid Build Coastguard Worker   return CUDAStreamForId(device_index, reinterpret_cast<int64_t>(ext_stream));
351*da0073e9SAndroid Build Coastguard Worker }
352*da0073e9SAndroid Build Coastguard Worker 
getDefaultCUDAStream(DeviceIndex device_index)353*da0073e9SAndroid Build Coastguard Worker CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
354*da0073e9SAndroid Build Coastguard Worker   initCUDAStreamsOnce();
355*da0073e9SAndroid Build Coastguard Worker   if (device_index == -1) {
356*da0073e9SAndroid Build Coastguard Worker     device_index = current_device();
357*da0073e9SAndroid Build Coastguard Worker     c10::cuda::SetTargetDevice();
358*da0073e9SAndroid Build Coastguard Worker   }
359*da0073e9SAndroid Build Coastguard Worker   check_gpu(device_index);
360*da0073e9SAndroid Build Coastguard Worker   return CUDAStreamForId(device_index, makeStreamId(StreamIdType::DEFAULT, 0));
361*da0073e9SAndroid Build Coastguard Worker }
362*da0073e9SAndroid Build Coastguard Worker 
getCurrentCUDAStream(DeviceIndex device_index)363*da0073e9SAndroid Build Coastguard Worker CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
364*da0073e9SAndroid Build Coastguard Worker   initCUDAStreamsOnce();
365*da0073e9SAndroid Build Coastguard Worker   if (device_index == -1) {
366*da0073e9SAndroid Build Coastguard Worker     device_index = current_device();
367*da0073e9SAndroid Build Coastguard Worker     c10::cuda::SetTargetDevice();
368*da0073e9SAndroid Build Coastguard Worker   }
369*da0073e9SAndroid Build Coastguard Worker   check_gpu(device_index);
370*da0073e9SAndroid Build Coastguard Worker   return CUDAStreamForId(device_index, current_streams[device_index]);
371*da0073e9SAndroid Build Coastguard Worker }
372*da0073e9SAndroid Build Coastguard Worker 
setCurrentCUDAStream(CUDAStream stream)373*da0073e9SAndroid Build Coastguard Worker void setCurrentCUDAStream(CUDAStream stream) {
374*da0073e9SAndroid Build Coastguard Worker   initCUDAStreamsOnce();
375*da0073e9SAndroid Build Coastguard Worker   current_streams[stream.device_index()] = stream.id();
376*da0073e9SAndroid Build Coastguard Worker }
377*da0073e9SAndroid Build Coastguard Worker 
operator <<(std::ostream & stream,const CUDAStream & s)378*da0073e9SAndroid Build Coastguard Worker std::ostream& operator<<(std::ostream& stream, const CUDAStream& s) {
379*da0073e9SAndroid Build Coastguard Worker   return stream << s.unwrap();
380*da0073e9SAndroid Build Coastguard Worker }
381*da0073e9SAndroid Build Coastguard Worker 
382*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
383