1*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/GPUTrace.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAFunctions.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAGuard.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAStream.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/CallOnce.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
7*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker #include <array>
10*da0073e9SAndroid Build Coastguard Worker #include <atomic>
11*da0073e9SAndroid Build Coastguard Worker #include <cstdint>
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker namespace {
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker // Global stream state and constants
18*da0073e9SAndroid Build Coastguard Worker static c10::once_flag init_flag;
19*da0073e9SAndroid Build Coastguard Worker static DeviceIndex num_gpus = -1;
20*da0073e9SAndroid Build Coastguard Worker static constexpr int kStreamsPerPoolBits = 5;
21*da0073e9SAndroid Build Coastguard Worker static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
22*da0073e9SAndroid Build Coastguard Worker static constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking;
23*da0073e9SAndroid Build Coastguard Worker static constexpr int kStreamTypeBits = 4;
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker static int max_stream_priorities;
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker // Non-default streams
28*da0073e9SAndroid Build Coastguard Worker // Note: the number of CUDA devices is determined at run time,
29*da0073e9SAndroid Build Coastguard Worker // and the low and high priority pools are lazily initialized
30*da0073e9SAndroid Build Coastguard Worker // when the first stream is requested for a device.
31*da0073e9SAndroid Build Coastguard Worker // The device flags track the initialization of each device, while
32*da0073e9SAndroid Build Coastguard Worker // the low and high priority counters track, for each device, the next stream
33*da0073e9SAndroid Build Coastguard Worker // in the pool to be returned when a stream is requested (round-robin fashion
34*da0073e9SAndroid Build Coastguard Worker // , see the note in CUDAStream.h).
35*da0073e9SAndroid Build Coastguard Worker // The streams are "leaked": they are created but never destroyed because the
36*da0073e9SAndroid Build Coastguard Worker // destruction of global variables could happen after the CUDA runtime has
37*da0073e9SAndroid Build Coastguard Worker // already been destroyed and thus invoking cudaStreamDestroy could lead to a
38*da0073e9SAndroid Build Coastguard Worker // crash. It's likely an issue in CUDA, but to be safe - let's just "forget"
39*da0073e9SAndroid Build Coastguard Worker // the destruction.
40*da0073e9SAndroid Build Coastguard Worker #if !defined(USE_ROCM)
41*da0073e9SAndroid Build Coastguard Worker // CUDA-only: used to initializes the stream pools (once)
42*da0073e9SAndroid Build Coastguard Worker static std::array<c10::once_flag, C10_COMPILE_TIME_MAX_GPUS> device_flags;
43*da0073e9SAndroid Build Coastguard Worker #endif
44*da0073e9SAndroid Build Coastguard Worker static std::array<
45*da0073e9SAndroid Build Coastguard Worker std::array<std::atomic<uint32_t>, C10_COMPILE_TIME_MAX_GPUS>,
46*da0073e9SAndroid Build Coastguard Worker c10::cuda::max_compile_time_stream_priorities>
47*da0073e9SAndroid Build Coastguard Worker priority_counters;
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker static std::array<
50*da0073e9SAndroid Build Coastguard Worker std::array<
51*da0073e9SAndroid Build Coastguard Worker std::array<cudaStream_t, kStreamsPerPool>,
52*da0073e9SAndroid Build Coastguard Worker C10_COMPILE_TIME_MAX_GPUS>,
53*da0073e9SAndroid Build Coastguard Worker c10::cuda::max_compile_time_stream_priorities>
54*da0073e9SAndroid Build Coastguard Worker streams;
55*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ROCM
56*da0073e9SAndroid Build Coastguard Worker static c10::once_flag
57*da0073e9SAndroid Build Coastguard Worker stream_flags[c10::cuda::max_compile_time_stream_priorities]
58*da0073e9SAndroid Build Coastguard Worker [C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool];
59*da0073e9SAndroid Build Coastguard Worker #endif
60*da0073e9SAndroid Build Coastguard Worker
// Note [HIP Lazy Streams]
// ~~~~~~~~~~~~~~~~~~~~~~~
// For ROCm/HIP, each stream is lazily initialized rather than creating all
// streams when the first stream is requested. HIP streams are not as
// lightweight as CUDA streams; the pooling strategy can affect performance.
// Rather than changing the pooling implementation, ROCm/HIP will lazy init
// each stream when it is first requested.
68*da0073e9SAndroid Build Coastguard Worker
// Note [StreamId assignment]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// How do we assign stream IDs?
//
// -- 54 bits --  -- 5 bits -----  -- 4 bits --     -- 1 bit --
// zeros          stream id index  StreamIdType     Ext/native stream
//                ignored for ext  ignored for ext
//
// For an external stream, the StreamId is the cudaStream_t pointer itself,
// which means that its last bit will always be 0. So when constructing the
// StreamId for a native stream we set the last bit to 1 to distinguish
// between native and external streams.
//
// We are obligated to treat the stream ID 0 as the default stream, per the
// invariant specified in c10::Stream, so this is the one exception to
// "last bit = 1 for native streams". However, all other numbers are entirely
// an internal implementation detail; we reserve the right to renumber streams
// however we like.
//
// Note that it is really important that the MSB is zero; StreamId is a
// *signed* integer, and unsigned to signed conversion outside of the
// bounds of the signed integer representation is implementation-defined
// behavior before C++20. You could work around this with something like
// https://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
// but it seems a bit overkill for this.
//
// Also, externally managed stream pointers (cudaStream_t) can be stored
// directly in the Id field, so in this case we need to check the stream
// alignment.
97*da0073e9SAndroid Build Coastguard Worker
// StreamIdType encodes whether a stream is DEFAULT, EXTernal or, for all
// other native streams, the stream priority (a higher value is a higher
// priority).
class StreamIdType {
 public:
  // constexpr static members are implicitly inline in C++17, so these tags can
  // be ODR-used (e.g. bound to a const reference) without needing a separate
  // out-of-class definition.
  static constexpr uint8_t DEFAULT = 0x0;
  static constexpr uint8_t EXT = 0xF;

  // Intentionally implicit: callers pass the raw tag constants directly,
  // e.g. makeStreamId(StreamIdType::DEFAULT, 0).
  // NOLINTNEXTLINE(google-explicit-constructor)
  StreamIdType(const uint8_t _stream_type) : stream_type(_stream_type) {}

  // True when the tag marks an externally allocated stream.
  bool isExt() const {
    return EXT == stream_type;
  }

  // True when the tag marks the default stream.
  bool isDefault() const {
    return DEFAULT == stream_type;
  }

  // Returns the raw tag value.
  uint8_t getStreamType() const {
    return stream_type;
  }

 private:
  uint8_t stream_type;
};
124*da0073e9SAndroid Build Coastguard Worker
operator <<(std::ostream & stream,StreamIdType s)125*da0073e9SAndroid Build Coastguard Worker std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
126*da0073e9SAndroid Build Coastguard Worker if (s.isDefault()) {
127*da0073e9SAndroid Build Coastguard Worker stream << "DEFAULT";
128*da0073e9SAndroid Build Coastguard Worker } else if (s.isExt()) {
129*da0073e9SAndroid Build Coastguard Worker stream << "EXT";
130*da0073e9SAndroid Build Coastguard Worker } else {
131*da0073e9SAndroid Build Coastguard Worker stream << "PRIORITY " << int(s.getStreamType());
132*da0073e9SAndroid Build Coastguard Worker }
133*da0073e9SAndroid Build Coastguard Worker return stream;
134*da0073e9SAndroid Build Coastguard Worker }
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker // StreamId is 64-bit, so we can just rely on regular promotion rules.
137*da0073e9SAndroid Build Coastguard Worker // We rely on streamIdIndex and streamIdType being non-negative;
138*da0073e9SAndroid Build Coastguard Worker // see Note [Hazard when concatenating signed integers]
139*da0073e9SAndroid Build Coastguard Worker
streamIdType(StreamId s)140*da0073e9SAndroid Build Coastguard Worker static inline StreamIdType streamIdType(StreamId s) {
141*da0073e9SAndroid Build Coastguard Worker // Externally allocated streams have their id being the cudaStream_ptr
142*da0073e9SAndroid Build Coastguard Worker // so the last bit will be 0
143*da0073e9SAndroid Build Coastguard Worker if ((!(s & 1)) && s) {
144*da0073e9SAndroid Build Coastguard Worker return StreamIdType(StreamIdType::EXT);
145*da0073e9SAndroid Build Coastguard Worker }
146*da0073e9SAndroid Build Coastguard Worker // last bit is external/internal stream, the mask should start from second
147*da0073e9SAndroid Build Coastguard Worker // rightmost bit
148*da0073e9SAndroid Build Coastguard Worker int mask_for_type = (1 << kStreamTypeBits) - 1;
149*da0073e9SAndroid Build Coastguard Worker auto val = (s >> 1) & mask_for_type;
150*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(val || !(s & 1), "invalid StreamId", s);
151*da0073e9SAndroid Build Coastguard Worker return StreamIdType(val);
152*da0073e9SAndroid Build Coastguard Worker }
153*da0073e9SAndroid Build Coastguard Worker
streamIdIndex(StreamId s)154*da0073e9SAndroid Build Coastguard Worker static inline size_t streamIdIndex(StreamId s) {
155*da0073e9SAndroid Build Coastguard Worker return static_cast<size_t>(
156*da0073e9SAndroid Build Coastguard Worker (s >> (kStreamTypeBits + 1)) & ((1 << kStreamsPerPoolBits) - 1));
157*da0073e9SAndroid Build Coastguard Worker }
158*da0073e9SAndroid Build Coastguard Worker
makeStreamId(StreamIdType st,size_t si)159*da0073e9SAndroid Build Coastguard Worker StreamId makeStreamId(StreamIdType st, size_t si) {
160*da0073e9SAndroid Build Coastguard Worker if (st.isDefault()) {
161*da0073e9SAndroid Build Coastguard Worker return static_cast<StreamId>(0);
162*da0073e9SAndroid Build Coastguard Worker }
163*da0073e9SAndroid Build Coastguard Worker return (static_cast<StreamId>(si) << (kStreamTypeBits + 1)) |
164*da0073e9SAndroid Build Coastguard Worker static_cast<StreamId>(st.getStreamType() << 1) | 1;
165*da0073e9SAndroid Build Coastguard Worker }
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker // Thread-local current streams
168*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(*-arrays)
169*da0073e9SAndroid Build Coastguard Worker static thread_local std::unique_ptr<StreamId[]> current_streams = nullptr;
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker // Populates global values.
172*da0073e9SAndroid Build Coastguard Worker // Warning: this function must only be called once!
initGlobalStreamState()173*da0073e9SAndroid Build Coastguard Worker static void initGlobalStreamState() {
174*da0073e9SAndroid Build Coastguard Worker num_gpus = device_count();
175*da0073e9SAndroid Build Coastguard Worker // Check if the number of GPUs matches the expected compile-time max number
176*da0073e9SAndroid Build Coastguard Worker // of GPUs.
177*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
178*da0073e9SAndroid Build Coastguard Worker num_gpus <= C10_COMPILE_TIME_MAX_GPUS,
179*da0073e9SAndroid Build Coastguard Worker "Number of CUDA devices on the machine is larger than the compiled "
180*da0073e9SAndroid Build Coastguard Worker "max number of gpus expected (",
181*da0073e9SAndroid Build Coastguard Worker C10_COMPILE_TIME_MAX_GPUS,
182*da0073e9SAndroid Build Coastguard Worker "). Increase that and recompile.");
183*da0073e9SAndroid Build Coastguard Worker int leastPriority = -1, greatestPriority = -1;
184*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(
185*da0073e9SAndroid Build Coastguard Worker cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority));
186*da0073e9SAndroid Build Coastguard Worker // Note [HIP stream priorities]
187*da0073e9SAndroid Build Coastguard Worker // HIP stream priorities are 1=low, 0=default, -1=high which differs from CUDA
188*da0073e9SAndroid Build Coastguard Worker // which is 0=default, -1=high, -2=higher etc.
189*da0073e9SAndroid Build Coastguard Worker // Clamp leastPriority to 0 for HIP.
190*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ROCM
191*da0073e9SAndroid Build Coastguard Worker leastPriority = 0;
192*da0073e9SAndroid Build Coastguard Worker #endif
193*da0073e9SAndroid Build Coastguard Worker // greatestPriority is negative
194*da0073e9SAndroid Build Coastguard Worker auto range = leastPriority - greatestPriority + 1;
195*da0073e9SAndroid Build Coastguard Worker max_stream_priorities = range >= c10::cuda::max_compile_time_stream_priorities
196*da0073e9SAndroid Build Coastguard Worker ? c10::cuda::max_compile_time_stream_priorities
197*da0073e9SAndroid Build Coastguard Worker : range;
198*da0073e9SAndroid Build Coastguard Worker }
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard Worker // Init a single CUDA or HIP stream
201*da0073e9SAndroid Build Coastguard Worker // See Note [HIP Lazy Streams]
initSingleStream(int p,DeviceIndex device_index,int i)202*da0073e9SAndroid Build Coastguard Worker static void initSingleStream(int p, DeviceIndex device_index, int i) {
203*da0073e9SAndroid Build Coastguard Worker auto& stream = streams[p][device_index][i];
204*da0073e9SAndroid Build Coastguard Worker auto pri = -p; // lower number is higher priority
205*da0073e9SAndroid Build Coastguard Worker
206*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaStreamCreateWithPriority(&stream, kDefaultFlags, pri));
207*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
208*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) {
209*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_stream_creation(
210*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
211*da0073e9SAndroid Build Coastguard Worker priority_counters[p][device_index] = 0;
212*da0073e9SAndroid Build Coastguard Worker }
213*da0073e9SAndroid Build Coastguard Worker }
214*da0073e9SAndroid Build Coastguard Worker
215*da0073e9SAndroid Build Coastguard Worker // Creates the low and high priority stream pools for the specified device
216*da0073e9SAndroid Build Coastguard Worker // Warning: only call once per device!
initDeviceStreamState(DeviceIndex device_index)217*da0073e9SAndroid Build Coastguard Worker static void initDeviceStreamState(DeviceIndex device_index) {
218*da0073e9SAndroid Build Coastguard Worker // Switches to the requested device so streams are properly associated
219*da0073e9SAndroid Build Coastguard Worker // with it.
220*da0073e9SAndroid Build Coastguard Worker CUDAGuard device_guard{device_index};
221*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(kStreamsPerPool)) {
222*da0073e9SAndroid Build Coastguard Worker for (const auto p : c10::irange(max_stream_priorities)) {
223*da0073e9SAndroid Build Coastguard Worker initSingleStream(p, device_index, i);
224*da0073e9SAndroid Build Coastguard Worker }
225*da0073e9SAndroid Build Coastguard Worker }
226*da0073e9SAndroid Build Coastguard Worker }
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker // Init front-end to ensure initialization only occurs once
initCUDAStreamsOnce()229*da0073e9SAndroid Build Coastguard Worker static void initCUDAStreamsOnce() {
230*da0073e9SAndroid Build Coastguard Worker // Inits default streams (once, globally)
231*da0073e9SAndroid Build Coastguard Worker c10::call_once(init_flag, initGlobalStreamState);
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker if (current_streams) {
234*da0073e9SAndroid Build Coastguard Worker return;
235*da0073e9SAndroid Build Coastguard Worker }
236*da0073e9SAndroid Build Coastguard Worker
237*da0073e9SAndroid Build Coastguard Worker // Inits current streams (thread local) to default streams
238*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(*-arrays)
239*da0073e9SAndroid Build Coastguard Worker current_streams = std::make_unique<StreamId[]>(num_gpus);
240*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(num_gpus)) {
241*da0073e9SAndroid Build Coastguard Worker current_streams[i] = makeStreamId(StreamIdType::DEFAULT, 0);
242*da0073e9SAndroid Build Coastguard Worker }
243*da0073e9SAndroid Build Coastguard Worker }
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker // Helper to verify the GPU index is valid
check_gpu(DeviceIndex device_index)246*da0073e9SAndroid Build Coastguard Worker static inline void check_gpu(DeviceIndex device_index) {
247*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(device_index >= 0 && device_index < num_gpus);
248*da0073e9SAndroid Build Coastguard Worker }
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker // Helper to determine the index of the stream to return
251*da0073e9SAndroid Build Coastguard Worker // Note: Streams are returned round-robin (see note in CUDAStream.h)
get_idx(std::atomic<uint32_t> & counter)252*da0073e9SAndroid Build Coastguard Worker static uint32_t get_idx(std::atomic<uint32_t>& counter) {
253*da0073e9SAndroid Build Coastguard Worker auto raw_idx = counter++;
254*da0073e9SAndroid Build Coastguard Worker return raw_idx % kStreamsPerPool;
255*da0073e9SAndroid Build Coastguard Worker }
256*da0073e9SAndroid Build Coastguard Worker
CUDAStreamForId(DeviceIndex device_index,StreamId stream_id)257*da0073e9SAndroid Build Coastguard Worker CUDAStream CUDAStreamForId(DeviceIndex device_index, StreamId stream_id) {
258*da0073e9SAndroid Build Coastguard Worker return CUDAStream(
259*da0073e9SAndroid Build Coastguard Worker CUDAStream::UNCHECKED,
260*da0073e9SAndroid Build Coastguard Worker Stream(
261*da0073e9SAndroid Build Coastguard Worker Stream::UNSAFE,
262*da0073e9SAndroid Build Coastguard Worker c10::Device(DeviceType::CUDA, device_index),
263*da0073e9SAndroid Build Coastguard Worker stream_id));
264*da0073e9SAndroid Build Coastguard Worker }
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker } // anonymous namespace
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker // See Note [StreamId assignment]
stream() const269*da0073e9SAndroid Build Coastguard Worker cudaStream_t CUDAStream::stream() const {
270*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device_index = stream_.device_index();
271*da0073e9SAndroid Build Coastguard Worker StreamId stream_id = stream_.id();
272*da0073e9SAndroid Build Coastguard Worker StreamIdType st = streamIdType(stream_id);
273*da0073e9SAndroid Build Coastguard Worker size_t si = streamIdIndex(stream_id);
274*da0073e9SAndroid Build Coastguard Worker if (st.isDefault()) {
275*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
276*da0073e9SAndroid Build Coastguard Worker si == 0,
277*da0073e9SAndroid Build Coastguard Worker "Unrecognized stream ",
278*da0073e9SAndroid Build Coastguard Worker stream_,
279*da0073e9SAndroid Build Coastguard Worker " (I think this should be the default stream, but I got a non-zero index ",
280*da0073e9SAndroid Build Coastguard Worker si,
281*da0073e9SAndroid Build Coastguard Worker ").",
282*da0073e9SAndroid Build Coastguard Worker " Did you manufacture the StreamId yourself? Don't do that; use the",
283*da0073e9SAndroid Build Coastguard Worker " official API like c10::cuda::getStreamFromPool() to get a new stream.");
284*da0073e9SAndroid Build Coastguard Worker return nullptr;
285*da0073e9SAndroid Build Coastguard Worker } else if (st.isExt()) {
286*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(performance-no-int-to-ptr)
287*da0073e9SAndroid Build Coastguard Worker return reinterpret_cast<cudaStream_t>(stream_id);
288*da0073e9SAndroid Build Coastguard Worker } else {
289*da0073e9SAndroid Build Coastguard Worker auto streamType = st.getStreamType();
290*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
291*da0073e9SAndroid Build Coastguard Worker streamType >= 1 && streamType <= max_stream_priorities,
292*da0073e9SAndroid Build Coastguard Worker "Unrecognized stream ",
293*da0073e9SAndroid Build Coastguard Worker stream_,
294*da0073e9SAndroid Build Coastguard Worker " (I didn't recognize the stream type, ",
295*da0073e9SAndroid Build Coastguard Worker st,
296*da0073e9SAndroid Build Coastguard Worker " with the value ",
297*da0073e9SAndroid Build Coastguard Worker streamType,
298*da0073e9SAndroid Build Coastguard Worker ")");
299*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ROCM
300*da0073e9SAndroid Build Coastguard Worker // See Note [HIP Lazy Streams]
301*da0073e9SAndroid Build Coastguard Worker c10::call_once(
302*da0073e9SAndroid Build Coastguard Worker stream_flags[st.getStreamType() - 1][device_index][si],
303*da0073e9SAndroid Build Coastguard Worker initSingleStream,
304*da0073e9SAndroid Build Coastguard Worker st.getStreamType() - 1,
305*da0073e9SAndroid Build Coastguard Worker device_index,
306*da0073e9SAndroid Build Coastguard Worker si);
307*da0073e9SAndroid Build Coastguard Worker #endif
308*da0073e9SAndroid Build Coastguard Worker return streams[st.getStreamType() - 1][device_index][si];
309*da0073e9SAndroid Build Coastguard Worker }
310*da0073e9SAndroid Build Coastguard Worker }
311*da0073e9SAndroid Build Coastguard Worker
312*da0073e9SAndroid Build Coastguard Worker // Returns a stream from the requested pool
313*da0073e9SAndroid Build Coastguard Worker // Note: when called the first time on a device, this will create the
314*da0073e9SAndroid Build Coastguard Worker // stream pools for that device.
getStreamFromPool(const int priority,DeviceIndex device_index)315*da0073e9SAndroid Build Coastguard Worker CUDAStream getStreamFromPool(const int priority, DeviceIndex device_index) {
316*da0073e9SAndroid Build Coastguard Worker initCUDAStreamsOnce();
317*da0073e9SAndroid Build Coastguard Worker if (device_index == -1) {
318*da0073e9SAndroid Build Coastguard Worker device_index = current_device();
319*da0073e9SAndroid Build Coastguard Worker c10::cuda::SetTargetDevice();
320*da0073e9SAndroid Build Coastguard Worker }
321*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
322*da0073e9SAndroid Build Coastguard Worker priority <= 0,
323*da0073e9SAndroid Build Coastguard Worker "Expected cuda stream priority to be less than or equal to 0, got ",
324*da0073e9SAndroid Build Coastguard Worker priority);
325*da0073e9SAndroid Build Coastguard Worker check_gpu(device_index);
326*da0073e9SAndroid Build Coastguard Worker #if !defined(USE_ROCM)
327*da0073e9SAndroid Build Coastguard Worker // See Note [HIP Lazy Streams]
328*da0073e9SAndroid Build Coastguard Worker // CUDA-only: Initializes the stream pools (once)
329*da0073e9SAndroid Build Coastguard Worker c10::call_once(
330*da0073e9SAndroid Build Coastguard Worker device_flags[device_index], initDeviceStreamState, device_index);
331*da0073e9SAndroid Build Coastguard Worker #endif
332*da0073e9SAndroid Build Coastguard Worker auto pri_idx = -priority;
333*da0073e9SAndroid Build Coastguard Worker pri_idx =
334*da0073e9SAndroid Build Coastguard Worker std::min(pri_idx, max_stream_priorities - 1); // pri_idx is zero-based
335*da0073e9SAndroid Build Coastguard Worker const auto idx = get_idx(priority_counters[pri_idx][device_index]);
336*da0073e9SAndroid Build Coastguard Worker StreamIdType id_type = StreamIdType(pri_idx + 1);
337*da0073e9SAndroid Build Coastguard Worker return CUDAStreamForId(device_index, makeStreamId(id_type, idx));
338*da0073e9SAndroid Build Coastguard Worker }
339*da0073e9SAndroid Build Coastguard Worker
getStreamFromPool(const bool isHighPriority,DeviceIndex device)340*da0073e9SAndroid Build Coastguard Worker CUDAStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) {
341*da0073e9SAndroid Build Coastguard Worker initCUDAStreamsOnce();
342*da0073e9SAndroid Build Coastguard Worker int priority = isHighPriority ? -max_stream_priorities + 1 : 0;
343*da0073e9SAndroid Build Coastguard Worker return getStreamFromPool(priority, device);
344*da0073e9SAndroid Build Coastguard Worker }
345*da0073e9SAndroid Build Coastguard Worker
getStreamFromExternal(cudaStream_t ext_stream,DeviceIndex device_index)346*da0073e9SAndroid Build Coastguard Worker CUDAStream getStreamFromExternal(
347*da0073e9SAndroid Build Coastguard Worker cudaStream_t ext_stream,
348*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_index) {
349*da0073e9SAndroid Build Coastguard Worker // The stream pointer will be the actual id
350*da0073e9SAndroid Build Coastguard Worker return CUDAStreamForId(device_index, reinterpret_cast<int64_t>(ext_stream));
351*da0073e9SAndroid Build Coastguard Worker }
352*da0073e9SAndroid Build Coastguard Worker
getDefaultCUDAStream(DeviceIndex device_index)353*da0073e9SAndroid Build Coastguard Worker CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
354*da0073e9SAndroid Build Coastguard Worker initCUDAStreamsOnce();
355*da0073e9SAndroid Build Coastguard Worker if (device_index == -1) {
356*da0073e9SAndroid Build Coastguard Worker device_index = current_device();
357*da0073e9SAndroid Build Coastguard Worker c10::cuda::SetTargetDevice();
358*da0073e9SAndroid Build Coastguard Worker }
359*da0073e9SAndroid Build Coastguard Worker check_gpu(device_index);
360*da0073e9SAndroid Build Coastguard Worker return CUDAStreamForId(device_index, makeStreamId(StreamIdType::DEFAULT, 0));
361*da0073e9SAndroid Build Coastguard Worker }
362*da0073e9SAndroid Build Coastguard Worker
getCurrentCUDAStream(DeviceIndex device_index)363*da0073e9SAndroid Build Coastguard Worker CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
364*da0073e9SAndroid Build Coastguard Worker initCUDAStreamsOnce();
365*da0073e9SAndroid Build Coastguard Worker if (device_index == -1) {
366*da0073e9SAndroid Build Coastguard Worker device_index = current_device();
367*da0073e9SAndroid Build Coastguard Worker c10::cuda::SetTargetDevice();
368*da0073e9SAndroid Build Coastguard Worker }
369*da0073e9SAndroid Build Coastguard Worker check_gpu(device_index);
370*da0073e9SAndroid Build Coastguard Worker return CUDAStreamForId(device_index, current_streams[device_index]);
371*da0073e9SAndroid Build Coastguard Worker }
372*da0073e9SAndroid Build Coastguard Worker
setCurrentCUDAStream(CUDAStream stream)373*da0073e9SAndroid Build Coastguard Worker void setCurrentCUDAStream(CUDAStream stream) {
374*da0073e9SAndroid Build Coastguard Worker initCUDAStreamsOnce();
375*da0073e9SAndroid Build Coastguard Worker current_streams[stream.device_index()] = stream.id();
376*da0073e9SAndroid Build Coastguard Worker }
377*da0073e9SAndroid Build Coastguard Worker
operator <<(std::ostream & stream,const CUDAStream & s)378*da0073e9SAndroid Build Coastguard Worker std::ostream& operator<<(std::ostream& stream, const CUDAStream& s) {
379*da0073e9SAndroid Build Coastguard Worker return stream << s.unwrap();
380*da0073e9SAndroid Build Coastguard Worker }
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
383