// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/hip/HIPConfig.h>

// The includes of HIPGuard.h
#include <c10/hip/impl/HIPGuardImpl.h>
#include <c10/hip/HIPMacros.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/util/Exception.h>

#include <c10/hip/impl/HIPGuardImpl.h>

#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>

// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces.  Sorry!
namespace c10 { namespace hip {

// Note [Masquerading as CUDA]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// c10_hip is very easy to understand: it is HIPified from c10_cuda,
// and anywhere you said CUDA, the source code now says HIP.  HIPified
// PyTorch is much harder to understand: it is HIPified from regular
// PyTorch, yes, but NO source-to-source translation from CUDA to
// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP".
// For example, when you use HIPified PyTorch, you say x.cuda() to
// move a tensor onto a ROCm device.  We call this situation "HIP
// masquerading as CUDA".
//
// This leads to a very awkward situation when we want to call c10_hip
// code from PyTorch, since c10_hip is expecting things to be called
// HIP, but PyTorch is calling them CUDA (masquerading as HIP).  To
// fix this impedance mismatch, we have MasqueradingAsCUDA variants
// for all c10_hip classes.  These translate between the "HIP" and "CUDA
// masquerading as HIP" worlds.  For example,
// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a
// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type()
// returns CUDA, getDevice() reports the current HIP device as a CUDA
// device.)
//
// We should be able to delete all of these classes entirely once
// we switch PyTorch to calling a HIP a HIP.
//
// When you add a new MasqueradingAsCUDA class/function, you need to
// also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py
//
// By the way, note that the cpp file associated with this also
// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with
// this HIP implementation.
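//
// A minimal sketch (an assumption based on the note above, not code in this
// header) of what that registry override in the companion .cpp presumably
// looks like:
//
//   #include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
//
//   namespace c10 { namespace hip {
//   // Replace the CUDA slot in the DeviceGuardImpl registry, so generic code
//   // asking for the "CUDA" guard impl actually gets this HIP one.
//   C10_REGISTER_GUARD_IMPL(CUDA, HIPGuardImplMasqueradingAsCUDA);
//   }} // namespace c10::hip
//
// Generic code can then reach this implementation through the registry, e.g.
// via c10::impl::getDeviceGuardImpl(c10::DeviceType::CUDA).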

struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface {
  static constexpr c10::DeviceType static_type = c10::DeviceType::CUDA;
  HIPGuardImplMasqueradingAsCUDA() {}
  HIPGuardImplMasqueradingAsCUDA(c10::DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == c10::DeviceType::CUDA);
  }
  c10::DeviceType type() const override {
    return c10::DeviceType::CUDA;
  }
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    Device old_device = getDevice();
    if (old_device.index() != d.index()) {
      C10_HIP_CHECK(hipSetDevice(d.index()));
    }
    return old_device;
  }
  Device getDevice() const override {
    int device;
    C10_HIP_CHECK(hipGetDevice(&device));
    return Device(c10::DeviceType::CUDA, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    C10_HIP_CHECK(hipSetDevice(d.index()));
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
  }
  Stream getStream(Device d) const noexcept override {
    return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultHIPStreamMasqueradingAsCUDA(d.index());
  }
  Stream getNewStream(Device d, int priority = 0) const override {
    return getStreamFromPoolMasqueradingAsCUDA(priority, d.index());
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override {
    return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index());
  }
  Stream exchangeStream(Stream s) const noexcept override {
    HIPStreamMasqueradingAsCUDA cs(s);
    auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index());
    setCurrentHIPStreamMasqueradingAsCUDA(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    int deviceCnt;
    hipError_t _err;
    _err = hipGetDeviceCount(&deviceCnt);
    if(_err != hipErrorNoDevice && _err != hipSuccess)
        C10_HIP_CHECK(_err);
    return deviceCnt;
  }

  // Event-related functions
  // Note: hipEventCreateWithFlags should be called on the same device as
  //  the recording stream's device.
  void createEvent(
    hipEvent_t* hip_event,
    const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to HIP flag
    auto hip_flag = hipEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
        hip_flag = hipEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
        hip_flag = hipEventDefault;
        break;
      default:
        TORCH_CHECK(false, "HIP event received unknown flag");
    }

    C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag));
  }

  void destroyEvent(
    void* event,
    const DeviceIndex device_index) const noexcept override {
    if (!event) return;
    auto hip_event = static_cast<hipEvent_t>(event);
    int orig_device;
    C10_HIP_CHECK_WARN(hipGetDevice(&orig_device));
    C10_HIP_CHECK_WARN(hipSetDevice(device_index));
    C10_HIP_CHECK_WARN(hipEventDestroy(hip_event));
    C10_HIP_CHECK_WARN(hipSetDevice(orig_device));
  }

  void record(void** event,
    const Stream& stream,
    const DeviceIndex device_index,
    const EventFlag flag) const override {
    TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
      "Event device index ",
      device_index,
      " does not match recording stream's device index ",
      stream.device_index(),
      ".");

    hipEvent_t hip_event = static_cast<hipEvent_t>(*event);
    HIPStreamMasqueradingAsCUDA hip_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!hip_event) createEvent(&hip_event, flag);
    C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream));
    // Makes the void* point to the (possibly just allocated) HIP event
    *event = hip_event;

    // Resets device
    setDevice(orig_device);
  }

  void block(
    void* event,
    const Stream& stream) const override {
    if (!event) return;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    HIPStreamMasqueradingAsCUDA hip_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_HIP_CHECK(hipStreamWaitEvent(
      hip_stream,
      hip_event,
      /*flags (must be zero)=*/ 0));
    setDevice(orig_device);
  }

  bool queryEvent(void* event) const override {
    if (!event) return true;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    const hipError_t err = hipEventQuery(hip_event);
    if (err != hipErrorNotReady) C10_HIP_CHECK(err);
    else {
      // ignore and clear the error if not ready
      (void)hipGetLastError();
    }
    return (err == hipSuccess);
  }

  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    HIPStreamMasqueradingAsCUDA hip_stream{stream};
    return hip_stream.query();
  }

  void synchronizeStream(const Stream& stream) const override {
    HIPStreamMasqueradingAsCUDA hip_stream{stream};
    hip_stream.synchronize();
  }

  void synchronizeEvent(void* event) const override {
    if (!event)
      return;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    C10_HIP_CHECK(hipEventSynchronize(hip_event));
  }

  void recordDataPtrOnStream(
    const c10::DataPtr& data_ptr,
    const Stream& stream) const override {
    HIPStreamMasqueradingAsCUDA hip_stream{stream};
    HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream);
  }

  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
      const override {
    TORCH_CHECK(
        event1 && event2,
        "Both events must be recorded before calculating elapsed time.");
    int orig_device;
    C10_HIP_CHECK(hipGetDevice(&orig_device));
    C10_HIP_CHECK(hipSetDevice(device_index));
    hipEvent_t hip_event1 = static_cast<hipEvent_t>(event1);
    hipEvent_t hip_event2 = static_cast<hipEvent_t>(event2);
    float time_ms = 0;
    // raise hipErrorNotReady if either event is recorded but not yet completed
    C10_HIP_CHECK(hipEventElapsedTime(&time_ms, hip_event1, hip_event2));
    C10_HIP_CHECK(hipSetDevice(orig_device));
    return static_cast<double>(time_ms);
  }
};

// All of the guards which have HIPGuardImpl burned in need to also have
// variants using HIPGuardImplMasqueradingAsCUDA.

/// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with
/// the correct InlineDeviceGuard burned in.  Sorry about the
/// copy-pasting.

struct HIPGuardMasqueradingAsCUDA {
  explicit HIPGuardMasqueradingAsCUDA() = delete;
  explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {}
  explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {}

  HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete;
  HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete;
  HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete;
  HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete;

  void set_device(Device device) { guard_.set_device(device); }
  void reset_device(Device device) { guard_.reset_device(device); }
  void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
  Device original_device() const { return guard_.original_device(); }
  Device current_device() const { return guard_.current_device(); }

 private:
  c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
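
// A minimal usage sketch (illustrative only; example_launch_on is a hypothetical
// helper, not part of PyTorch): the guard switches to the tensor's device for the
// enclosing scope and restores the previous device on exit, while reporting the
// device type as CUDA.
//
//   void example_launch_on(const at::Tensor& t) {
//     c10::hip::HIPGuardMasqueradingAsCUDA guard(t.device());
//     // ... enqueue HIP work on t's device here ...
//   }  // previous device restored when guard goes out of scope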

struct OptionalHIPGuardMasqueradingAsCUDA {
  explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {}
  explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<Device> device_opt) : guard_(device_opt) {}
  explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}

  OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
  OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
  OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
  OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;

  void set_device(Device device) { guard_.set_device(device); }
  void reset_device(Device device) { guard_.reset_device(device); }
  void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
  std::optional<Device> original_device() const { return guard_.original_device(); }
  std::optional<Device> current_device() const { return guard_.current_device(); }
  void reset() { guard_.reset(); }

private:
  c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};

struct HIPStreamGuardMasqueradingAsCUDA {
  explicit HIPStreamGuardMasqueradingAsCUDA() = delete;
  explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
  HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
  HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
  HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
  HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;

  void reset_stream(Stream stream) { guard_.reset_stream(stream); }

  HIPStreamMasqueradingAsCUDA original_stream() const {
    return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream());
  }
  HIPStreamMasqueradingAsCUDA current_stream() const {
    return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream());
  }

  Device current_device() const { return guard_.current_device(); }
  Device original_device() const { return guard_.original_device(); }

private:
  c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
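
// A minimal usage sketch (illustrative only; example_async_work is a hypothetical
// helper): the stream guard makes the given stream, and its device, current for
// the scope, then restores the original stream and device.
//
//   void example_async_work(c10::hip::HIPStreamMasqueradingAsCUDA s) {
//     c10::hip::HIPStreamGuardMasqueradingAsCUDA guard(s.unwrap());
//     // ... work enqueued here targets s ...
//   }  // original stream and device restored here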

struct OptionalHIPStreamGuardMasqueradingAsCUDA {
  explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {}
  explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
  explicit OptionalHIPStreamGuardMasqueradingAsCUDA(std::optional<Stream> stream_opt) : guard_(stream_opt) {}

  OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
  OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
  OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
  OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;

  void reset_stream(Stream stream) { guard_.reset_stream(stream); }

  std::optional<HIPStreamMasqueradingAsCUDA> original_stream() const {
    auto r = guard_.original_stream();
    if (r.has_value()) {
      return std::make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()));
    } else {
      return std::nullopt;
    }
  }

  std::optional<HIPStreamMasqueradingAsCUDA> current_stream() const {
    auto r = guard_.current_stream();
    if (r.has_value()) {
      return std::make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()));
    } else {
      return std::nullopt;
    }
  }

  void reset() { guard_.reset(); }

private:
  c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};

struct HIPMultiStreamGuardMasqueradingAsCUDA {
  explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef<HIPStreamMasqueradingAsCUDA> streams)
    : guard_(unwrapStreams(streams)) {}

  HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
  HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
  HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
  HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;

private:
  c10::impl::InlineMultiStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;

  static std::vector<Stream> unwrapStreams(ArrayRef<HIPStreamMasqueradingAsCUDA> hipStreams) {
    std::vector<Stream> streams;
    streams.reserve(hipStreams.size());
    for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) {
      streams.push_back(hipStream);
    }
    return streams;
  }
};

}} // namespace c10::hip