1 #pragma once
2
3 #include <c10/hip/HIPStream.h>
4
5 // Use of c10::hip namespace here makes hipification easier, because
6 // I don't have to also fix namespaces. Sorry!
7 namespace c10 { namespace hip {
8
9 // See Note [Masquerading as CUDA] for motivation
10
11 class HIPStreamMasqueradingAsCUDA {
12 public:
13
14 enum Unchecked { UNCHECKED };
15
HIPStreamMasqueradingAsCUDA(Stream stream)16 explicit HIPStreamMasqueradingAsCUDA(Stream stream)
17 : HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) {
18 // We did the coercion unchecked; check that it was right.
19 TORCH_CHECK(stream.device().is_cuda() /* !!! */);
20 }
21
HIPStreamMasqueradingAsCUDA(Unchecked,Stream stream)22 explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream)
23 // Unsafely coerce the "CUDA" stream into a HIP stream
24 : stream_(
25 HIPStream(
26 Stream(
27 Stream::UNSAFE,
28 Device(c10::DeviceType::HIP, stream.device_index()),
29 stream.id())
30 )
31 ) {}
32
33 // New constructor, just for this. Does NOT coerce.
HIPStreamMasqueradingAsCUDA(HIPStream stream)34 explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {}
35
36 bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
37 return stream_ == other.stream_;
38 }
39
40 bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
41 return stream_ != other.stream_;
42 }
43
hipStream_t()44 operator hipStream_t() const { return stream_.stream(); }
45
Stream()46 operator Stream() const {
47 // Unsafely coerce HIP stream into a "CUDA" stream
48 return Stream(Stream::UNSAFE, device(), id());
49 }
50
device_index()51 DeviceIndex device_index() const { return stream_.device_index(); }
52
53 // Unsafely coerce HIP device into CUDA device
device_type()54 c10::DeviceType device_type() const { return c10::DeviceType::CUDA; }
55
device()56 Device device() const {
57 // Unsafely coerce HIP device into CUDA device
58 return Device(c10::DeviceType::CUDA, stream_.device_index());
59 }
60
id()61 StreamId id() const { return stream_.id(); }
query()62 bool query() const { return stream_.query(); }
synchronize()63 void synchronize() const { stream_.synchronize(); }
priority()64 int priority() const { return stream_.priority(); }
stream()65 hipStream_t stream() const { return stream_.stream(); }
66
unwrap()67 Stream unwrap() const {
68 // Unsafely coerce HIP stream into "CUDA" stream
69 return Stream(Stream::UNSAFE, device(), id());
70 }
71
pack3()72 c10::StreamData3 pack3() const noexcept {
73 // Unsafely coerce HIP stream into "CUDA" stream before packing
74 return unwrap().pack3();
75 }
76
unpack3(StreamId stream_id,DeviceIndex device_index,c10::DeviceType device_type)77 static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id,
78 DeviceIndex device_index,
79 c10::DeviceType device_type) {
80 // NB: constructor manages CUDA->HIP translation for us
81 return HIPStreamMasqueradingAsCUDA(Stream::unpack3(
82 stream_id, device_index, device_type));
83 }
84
priority_range()85 static std::tuple<int, int> priority_range() { return HIPStream::priority_range(); }
86
87 // New method, gets the underlying HIPStream
hip_stream()88 HIPStream hip_stream() const { return stream_; }
89
90 private:
91 HIPStream stream_;
92 };
93
94 HIPStreamMasqueradingAsCUDA
95 inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) {
96 return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device));
97 }
98
99 HIPStreamMasqueradingAsCUDA
100 inline getStreamFromPoolMasqueradingAsCUDA(const int priority, DeviceIndex device = -1) {
101 return HIPStreamMasqueradingAsCUDA(getStreamFromPool(priority, device));
102 }
103
104 HIPStreamMasqueradingAsCUDA
getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream,DeviceIndex device)105 inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) {
106 return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device));
107 }
108
109 inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
110 return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index));
111 }
112
113 inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
114 return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index));
115 }
116
setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream)117 inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) {
118 setCurrentHIPStream(stream.hip_stream());
119 }
120
121 inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) {
122 stream << s.hip_stream() << " (masquerading as CUDA)";
123 return stream;
124 }
125
126 }} // namespace c10::hip
127
128 namespace std {
129 template <>
130 struct hash<c10::hip::HIPStreamMasqueradingAsCUDA> {
131 size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept {
132 return std::hash<c10::Stream>{}(s.unwrap());
133 }
134 };
135 } // namespace std
136