1 #pragma once 2 3 #include <ATen/hip/HIPConfig.h> 4 5 // The includes of HIPGuard.h 6 #include <c10/hip/impl/HIPGuardImpl.h> 7 #include <c10/hip/HIPMacros.h> 8 #include <c10/core/DeviceType.h> 9 #include <c10/core/impl/InlineDeviceGuard.h> 10 #include <c10/core/impl/InlineStreamGuard.h> 11 #include <c10/util/Exception.h> 12 13 #include <c10/hip/impl/HIPGuardImpl.h> 14 15 #include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h> 16 #include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h> 17 18 // Use of c10::hip namespace here makes hipification easier, because 19 // I don't have to also fix namespaces. Sorry! 20 namespace c10 { namespace hip { 21 22 // Note [Masquerading as CUDA] 23 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 24 // c10_hip is very easy to understand: it is HIPified from c10_cuda, 25 // and anywhere you said CUDA, the source code now says HIP. HIPified 26 // PyTorch is much harder to understand: it is HIPified from regular 27 // PyTorch, yes, but NO source-to-source translation from CUDA to 28 // HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP". 29 // For example, when you use HIPified PyTorch, you say x.cuda() to 30 // move a tensor onto ROCm device. We call this situation "HIP 31 // masquerading as CUDA". 32 // 33 // This leads to a very awkward situation when we want to call c10_hip 34 // code from PyTorch, since c10_hip is expecting things to be called 35 // HIP, but PyTorch is calling them CUDA (masquerading as HIP). To 36 // fix this impedance mismatch, we have MasqueradingAsCUDA variants 37 // for all c10_hip classes. These translate between the "HIP" and "CUDA 38 // masquerading as HIP" worlds. For example, 39 // HIPGuardImplMasqueradingAsCUDA (this file) provides something like a 40 // HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type() 41 // returns CUDA, getDevice() reports the current HIP device as a CUDA 42 // device.) 43 // 44 // We should be able to delete all of these classes entirely once 45 // we switch PyTorch to calling a HIP a HIP. 46 // 47 // When you add a new MasqueradingAsCUDA class/function, you need to 48 // also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py 49 // 50 // 51 // 52 // By the way, note that the cpp file associated with this also 53 // *overwrites* the entry in the DeviceGuardImpl registry for CUDA with 54 // this HIP implementation. 55 56 struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface { 57 static constexpr c10::DeviceType static_type = c10::DeviceType::CUDA; HIPGuardImplMasqueradingAsCUDAfinal58 HIPGuardImplMasqueradingAsCUDA() {} HIPGuardImplMasqueradingAsCUDAfinal59 HIPGuardImplMasqueradingAsCUDA(c10::DeviceType t) { 60 TORCH_INTERNAL_ASSERT(t == c10::DeviceType::CUDA); 61 } typefinal62 c10::DeviceType type() const override { 63 return c10::DeviceType::CUDA; 64 } exchangeDevicefinal65 Device exchangeDevice(Device d) const override { 66 TORCH_INTERNAL_ASSERT(d.is_cuda()); 67 Device old_device = getDevice(); 68 if (old_device.index() != d.index()) { 69 C10_HIP_CHECK(hipSetDevice(d.index())); 70 } 71 return old_device; 72 } getDevicefinal73 Device getDevice() const override { 74 int device; 75 C10_HIP_CHECK(hipGetDevice(&device)); 76 return Device(c10::DeviceType::CUDA, device); 77 } setDevicefinal78 void setDevice(Device d) const override { 79 TORCH_INTERNAL_ASSERT(d.is_cuda()); 80 C10_HIP_CHECK(hipSetDevice(d.index())); 81 } uncheckedSetDevicefinal82 void uncheckedSetDevice(Device d) const noexcept override { 83 C10_HIP_CHECK_WARN(hipSetDevice(d.index())); 84 } getStreamfinal85 Stream getStream(Device d) const noexcept override { 86 return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap(); 87 } getDefaultStreamfinal88 Stream getDefaultStream(Device d) const override { 89 return getDefaultHIPStreamMasqueradingAsCUDA(d.index()); 90 } 91 Stream getNewStream(Device d, int priority = 0) const override { 92 return getStreamFromPoolMasqueradingAsCUDA(priority, d.index()); 93 } 94 Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override { 95 return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index()); 96 } exchangeStreamfinal97 Stream exchangeStream(Stream s) const noexcept override { 98 HIPStreamMasqueradingAsCUDA cs(s); 99 auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index()); 100 setCurrentHIPStreamMasqueradingAsCUDA(cs); 101 return old_stream.unwrap(); 102 } deviceCountfinal103 DeviceIndex deviceCount() const noexcept override { 104 int deviceCnt; 105 hipError_t _err; 106 _err = hipGetDeviceCount(&deviceCnt); 107 if(_err != hipErrorNoDevice && _err != hipSuccess) 108 C10_HIP_CHECK(_err); 109 return deviceCnt; 110 } 111 112 // Event-related functions 113 // Note: hipEventCreateWithFlags should be called on the same device as 114 // the recording stream's device. createEventfinal115 void createEvent( 116 hipEvent_t* hip_event, 117 const EventFlag flag) const { 118 // Maps PyTorch's Event::Flag to HIP flag 119 auto hip_flag = hipEventDefault; 120 switch (flag) { 121 case EventFlag::PYTORCH_DEFAULT: 122 hip_flag = hipEventDisableTiming; 123 break; 124 case EventFlag::BACKEND_DEFAULT: 125 hip_flag = hipEventDefault; 126 break; 127 default: 128 TORCH_CHECK(false, "HIP event received unknown flag"); 129 } 130 131 C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag)); 132 } 133 destroyEventfinal134 void destroyEvent( 135 void* event, 136 const DeviceIndex device_index) const noexcept override { 137 if (!event) return; 138 auto hip_event = static_cast<hipEvent_t>(event); 139 int orig_device; 140 C10_HIP_CHECK_WARN(hipGetDevice(&orig_device)); 141 C10_HIP_CHECK_WARN(hipSetDevice(device_index)); 142 C10_HIP_CHECK_WARN(hipEventDestroy(hip_event)); 143 C10_HIP_CHECK_WARN(hipSetDevice(orig_device)); 144 } 145 recordfinal146 void record(void** event, 147 const Stream& stream, 148 const DeviceIndex device_index, 149 const EventFlag flag) const override { 150 TORCH_CHECK(device_index == -1 || device_index == stream.device_index(), 151 "Event device index ", 152 device_index, 153 " does not match recording stream's device index ", 154 stream.device_index(), 155 "."); 156 157 hipEvent_t hip_event = static_cast<hipEvent_t>(*event); 158 HIPStreamMasqueradingAsCUDA hip_stream{stream}; 159 160 // Moves to stream's device to record 161 const auto orig_device = getDevice(); 162 setDevice(stream.device()); 163 164 // Creates the event (lazily) 165 if (!hip_event) createEvent(&hip_event, flag); 166 C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream)); 167 // Makes the void* point to the (possibly just allocated) HIP event 168 *event = hip_event; 169 170 // Resets device 171 setDevice(orig_device); 172 } 173 blockfinal174 void block( 175 void* event, 176 const Stream& stream) const override { 177 if (!event) return; 178 hipEvent_t hip_event = static_cast<hipEvent_t>(event); 179 HIPStreamMasqueradingAsCUDA hip_stream{stream}; 180 const auto orig_device = getDevice(); 181 setDevice(stream.device()); 182 C10_HIP_CHECK(hipStreamWaitEvent( 183 hip_stream, 184 hip_event, 185 /*flags (must be zero)=*/ 0)); 186 setDevice(orig_device); 187 } 188 queryEventfinal189 bool queryEvent(void* event) const override { 190 if (!event) return true; 191 hipEvent_t hip_event = static_cast<hipEvent_t>(event); 192 const hipError_t err = hipEventQuery(hip_event); 193 if (err != hipErrorNotReady) C10_HIP_CHECK(err); 194 else { 195 // ignore and clear the error if not ready 196 (void)hipGetLastError(); 197 } 198 return (err == hipSuccess); 199 } 200 201 // Stream-related functions queryStreamfinal202 bool queryStream(const Stream& stream) const override { 203 HIPStreamMasqueradingAsCUDA hip_stream{stream}; 204 return hip_stream.query(); 205 } 206 synchronizeStreamfinal207 void synchronizeStream(const Stream& stream) const override { 208 HIPStreamMasqueradingAsCUDA hip_stream{stream}; 209 hip_stream.synchronize(); 210 } 211 synchronizeEventfinal212 void synchronizeEvent(void* event) const override { 213 if (!event) 214 return; 215 hipEvent_t hip_event = static_cast<hipEvent_t>(event); 216 C10_HIP_CHECK(hipEventSynchronize(hip_event)); 217 } 218 recordDataPtrOnStreamfinal219 void recordDataPtrOnStream( 220 const c10::DataPtr& data_ptr, 221 const Stream& stream) const override { 222 HIPStreamMasqueradingAsCUDA hip_stream{stream}; 223 HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream); 224 } 225 elapsedTimefinal226 double elapsedTime(void* event1, void* event2, const DeviceIndex device_index) 227 const override { 228 TORCH_CHECK( 229 event1 && event2, 230 "Both events must be recorded before calculating elapsed time."); 231 int orig_device; 232 C10_HIP_CHECK(hipGetDevice(&orig_device)); 233 C10_HIP_CHECK(hipSetDevice(device_index)); 234 hipEvent_t hip_event1 = static_cast<hipEvent_t>(event1); 235 hipEvent_t hip_event2 = static_cast<hipEvent_t>(event2); 236 float time_ms = 0; 237 // raise hipErrorNotReady if either event is recorded but not yet completed 238 C10_HIP_CHECK(hipEventElapsedTime(&time_ms, hip_event1, hip_event2)); 239 C10_HIP_CHECK(hipSetDevice(orig_device)); 240 return static_cast<double>(time_ms); 241 } 242 }; 243 244 // All of the guards which have HIPGuardImpl burned in need to also have 245 // variants using HIPGuardImplMasqueradingAsCUDA. 246 247 /// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with 248 /// the correct InlineDeviceGuard burned in. Sorry about the 249 /// copy-pasting. 250 251 struct HIPGuardMasqueradingAsCUDA { 252 explicit HIPGuardMasqueradingAsCUDA() = delete; HIPGuardMasqueradingAsCUDAHIPGuardMasqueradingAsCUDA253 explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {} HIPGuardMasqueradingAsCUDAHIPGuardMasqueradingAsCUDA254 explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {} 255 256 HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete; 257 HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete; 258 HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete; 259 HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete; 260 set_deviceHIPGuardMasqueradingAsCUDA261 void set_device(Device device) { guard_.set_device(device); } reset_deviceHIPGuardMasqueradingAsCUDA262 void reset_device(Device device) { guard_.reset_device(device); } set_indexHIPGuardMasqueradingAsCUDA263 void set_index(DeviceIndex device_index) { guard_.set_index(device_index); } original_deviceHIPGuardMasqueradingAsCUDA264 Device original_device() const { return guard_.original_device(); } current_deviceHIPGuardMasqueradingAsCUDA265 Device current_device() const { return guard_.current_device(); } 266 267 private: 268 c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_; 269 }; 270 271 struct OptionalHIPGuardMasqueradingAsCUDA { OptionalHIPGuardMasqueradingAsCUDAOptionalHIPGuardMasqueradingAsCUDA272 explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {} OptionalHIPGuardMasqueradingAsCUDAOptionalHIPGuardMasqueradingAsCUDA273 explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<Device> device_opt) : guard_(device_opt) {} OptionalHIPGuardMasqueradingAsCUDAOptionalHIPGuardMasqueradingAsCUDA274 explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {} 275 276 OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete; 277 OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete; 278 OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete; 279 OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete; 280 set_deviceOptionalHIPGuardMasqueradingAsCUDA281 void set_device(Device device) { guard_.set_device(device); } reset_deviceOptionalHIPGuardMasqueradingAsCUDA282 void reset_device(Device device) { guard_.reset_device(device); } set_indexOptionalHIPGuardMasqueradingAsCUDA283 void set_index(DeviceIndex device_index) { guard_.set_index(device_index); } original_deviceOptionalHIPGuardMasqueradingAsCUDA284 std::optional<Device> original_device() const { return guard_.original_device(); } current_deviceOptionalHIPGuardMasqueradingAsCUDA285 std::optional<Device> current_device() const { return guard_.current_device(); } resetOptionalHIPGuardMasqueradingAsCUDA286 void reset() { guard_.reset(); } 287 288 private: 289 c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_; 290 }; 291 292 struct HIPStreamGuardMasqueradingAsCUDA { 293 explicit HIPStreamGuardMasqueradingAsCUDA() = delete; HIPStreamGuardMasqueradingAsCUDAHIPStreamGuardMasqueradingAsCUDA294 explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {} 295 HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete; 296 HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete; 297 HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete; 298 HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete; 299 reset_streamHIPStreamGuardMasqueradingAsCUDA300 void reset_stream(Stream stream) { guard_.reset_stream(stream); } 301 original_streamHIPStreamGuardMasqueradingAsCUDA302 HIPStreamMasqueradingAsCUDA original_stream() const { 303 return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream()); 304 } current_streamHIPStreamGuardMasqueradingAsCUDA305 HIPStreamMasqueradingAsCUDA current_stream() const { 306 return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream()); 307 } 308 current_deviceHIPStreamGuardMasqueradingAsCUDA309 Device current_device() const { return guard_.current_device(); } original_deviceHIPStreamGuardMasqueradingAsCUDA310 Device original_device() const { return guard_.original_device(); } 311 312 private: 313 c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_; 314 }; 315 316 struct OptionalHIPStreamGuardMasqueradingAsCUDA { OptionalHIPStreamGuardMasqueradingAsCUDAOptionalHIPStreamGuardMasqueradingAsCUDA317 explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {} OptionalHIPStreamGuardMasqueradingAsCUDAOptionalHIPStreamGuardMasqueradingAsCUDA318 explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {} OptionalHIPStreamGuardMasqueradingAsCUDAOptionalHIPStreamGuardMasqueradingAsCUDA319 explicit OptionalHIPStreamGuardMasqueradingAsCUDA(std::optional<Stream> stream_opt) : guard_(stream_opt) {} 320 321 OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete; 322 OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete; 323 OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete; 324 OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete; 325 reset_streamOptionalHIPStreamGuardMasqueradingAsCUDA326 void reset_stream(Stream stream) { guard_.reset_stream(stream); } 327 original_streamOptionalHIPStreamGuardMasqueradingAsCUDA328 std::optional<HIPStreamMasqueradingAsCUDA> original_stream() const { 329 auto r = guard_.original_stream(); 330 if (r.has_value()) { 331 return std::make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value())); 332 } else { 333 return std::nullopt; 334 } 335 } 336 current_streamOptionalHIPStreamGuardMasqueradingAsCUDA337 std::optional<HIPStreamMasqueradingAsCUDA> current_stream() const { 338 auto r = guard_.current_stream(); 339 if (r.has_value()) { 340 return std::make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value())); 341 } else { 342 return std::nullopt; 343 } 344 } 345 resetOptionalHIPStreamGuardMasqueradingAsCUDA346 void reset() { guard_.reset(); } 347 348 private: 349 c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_; 350 }; 351 352 struct HIPMultiStreamGuardMasqueradingAsCUDA { HIPMultiStreamGuardMasqueradingAsCUDAHIPMultiStreamGuardMasqueradingAsCUDA353 explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef<HIPStreamMasqueradingAsCUDA> streams) 354 : guard_(unwrapStreams(streams)) {} 355 356 HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete; 357 HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete; 358 HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete; 359 HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete; 360 361 private: 362 c10::impl::InlineMultiStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_; 363 unwrapStreamsHIPMultiStreamGuardMasqueradingAsCUDA364 static std::vector<Stream> unwrapStreams(ArrayRef<HIPStreamMasqueradingAsCUDA> hipStreams) { 365 std::vector<Stream> streams; 366 streams.reserve(hipStreams.size()); 367 for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) { 368 streams.push_back(hipStream); 369 } 370 return streams; 371 } 372 }; 373 374 }} // namespace c10::hip 375