xref: /aosp_15_r20/external/pytorch/aten/src/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/hip/HIPStream.h>
4 
// Using the c10::hip namespace here makes hipification easier, because the
// namespaces do not also have to be rewritten by hand.
7 namespace c10 { namespace hip {
8 
9 // See Note [Masquerading as CUDA] for motivation
10 
11 class HIPStreamMasqueradingAsCUDA {
12 public:
13 
14   enum Unchecked { UNCHECKED };
15 
HIPStreamMasqueradingAsCUDA(Stream stream)16   explicit HIPStreamMasqueradingAsCUDA(Stream stream)
17     : HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) {
18     // We did the coercion unchecked; check that it was right.
19     TORCH_CHECK(stream.device().is_cuda() /* !!! */);
20   }
21 
HIPStreamMasqueradingAsCUDA(Unchecked,Stream stream)22   explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream)
23     // Unsafely coerce the "CUDA" stream into a HIP stream
24     : stream_(
25         HIPStream(
26           Stream(
27             Stream::UNSAFE,
28             Device(c10::DeviceType::HIP, stream.device_index()),
29             stream.id())
30         )
31       ) {}
32 
33   // New constructor, just for this.  Does NOT coerce.
HIPStreamMasqueradingAsCUDA(HIPStream stream)34   explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {}
35 
36   bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
37     return stream_ == other.stream_;
38   }
39 
40   bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
41     return stream_ != other.stream_;
42   }
43 
hipStream_t()44   operator hipStream_t() const { return stream_.stream(); }
45 
Stream()46   operator Stream() const {
47     // Unsafely coerce HIP stream into a "CUDA" stream
48     return Stream(Stream::UNSAFE, device(), id());
49   }
50 
device_index()51   DeviceIndex device_index() const { return stream_.device_index(); }
52 
53   // Unsafely coerce HIP device into CUDA device
device_type()54   c10::DeviceType device_type() const { return c10::DeviceType::CUDA; }
55 
device()56   Device device() const {
57     // Unsafely coerce HIP device into CUDA device
58     return Device(c10::DeviceType::CUDA, stream_.device_index());
59   }
60 
id()61   StreamId id() const        { return stream_.id(); }
query()62   bool query() const         { return stream_.query(); }
synchronize()63   void synchronize() const   { stream_.synchronize(); }
priority()64   int priority() const       { return stream_.priority(); }
stream()65   hipStream_t stream() const { return stream_.stream(); }
66 
unwrap()67   Stream unwrap() const {
68     // Unsafely coerce HIP stream into "CUDA" stream
69     return Stream(Stream::UNSAFE, device(), id());
70   }
71 
pack3()72   c10::StreamData3 pack3() const noexcept {
73     // Unsafely coerce HIP stream into "CUDA" stream before packing
74     return unwrap().pack3();
75   }
76 
unpack3(StreamId stream_id,DeviceIndex device_index,c10::DeviceType device_type)77   static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id,
78                                              DeviceIndex device_index,
79                                              c10::DeviceType device_type) {
80     // NB: constructor manages CUDA->HIP translation for us
81     return HIPStreamMasqueradingAsCUDA(Stream::unpack3(
82         stream_id, device_index, device_type));
83   }
84 
priority_range()85   static std::tuple<int, int> priority_range() { return HIPStream::priority_range(); }
86 
87   // New method, gets the underlying HIPStream
hip_stream()88   HIPStream hip_stream() const { return stream_; }
89 
90 private:
91   HIPStream stream_;
92 };
93 
94 HIPStreamMasqueradingAsCUDA
95 inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) {
96   return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device));
97 }
98 
99 HIPStreamMasqueradingAsCUDA
100 inline getStreamFromPoolMasqueradingAsCUDA(const int priority, DeviceIndex device = -1) {
101   return HIPStreamMasqueradingAsCUDA(getStreamFromPool(priority, device));
102 }
103 
104 HIPStreamMasqueradingAsCUDA
getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream,DeviceIndex device)105 inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) {
106   return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device));
107 }
108 
109 inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
110   return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index));
111 }
112 
113 inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
114   return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index));
115 }
116 
setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream)117 inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) {
118   setCurrentHIPStream(stream.hip_stream());
119 }
120 
121 inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) {
122   stream << s.hip_stream() << " (masquerading as CUDA)";
123   return stream;
124 }
125 
126 }} // namespace c10::hip
127 
128 namespace std {
129   template <>
130   struct hash<c10::hip::HIPStreamMasqueradingAsCUDA> {
131     size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept {
132       return std::hash<c10::Stream>{}(s.unwrap());
133     }
134   };
135 } // namespace std
136