1 #pragma once 2 3 // This file provides implementations of InlineDeviceGuard and 4 // InlineOptionalDeviceGuard. 5 6 #include <c10/core/Device.h> 7 #include <c10/core/DeviceType.h> 8 #include <c10/core/impl/DeviceGuardImplInterface.h> 9 #include <c10/core/impl/VirtualGuardImpl.h> 10 #include <c10/util/Exception.h> 11 #include <c10/util/Optional.h> 12 #include <type_traits> 13 #include <utility> 14 15 namespace c10::impl { 16 17 /** 18 * A DeviceGuard is an RAII class that sets a device to some value 19 * on construction, and resets the device to its original value on 20 * destruction. 21 * 22 * InlineDeviceGuard is a helper class for implementing DeviceGuards. 23 * It is templated over a DeviceGuardImpl (anything that implements 24 * DeviceGuardImplInterface). There are two primary ways to instantiate 25 * InlineDeviceGuard: 26 * 27 * - With a concrete implementation of DeviceGuardImpl, e.g., CUDAGuardImpl. 28 * This is the best way to use InlineDeviceGuard, as all calls are 29 * devirtualized, giving you code as efficient as straight line 30 * calls to cudaGetDevice/cudaSetDevice. 31 * 32 * - With VirtualGuardImpl, which does a virtual dispatch to a DeviceGuardImpl 33 * retrieved from a DeviceType registry. We have explicitly instantiated 34 * InlineDeviceGuard this way as c10::DeviceGuard. 35 * 36 * If you are in a hurry, you can use InlineDeviceGuard directly: 37 * 38 * using CUDAGuard = impl::InlineDeviceGuard<CUDAGuardImpl>; 39 * 40 * However, you can provide a better user experience if you explicitly write a 41 * wrapper class that itself contains the template instantiation: 42 * 43 * class CUDAGuard { 44 * public: 45 * // ... the API ... 46 * private: 47 * impl::InlineDeviceGuard<CUDAGuardImpl> guard_; 48 * } 49 * 50 * The wrapper class provides a good place to write documentation, and helps 51 * avoid weird template instantiation errors when a user incorrectly uses the 52 * class. 53 * 54 * If you need to test this class, consider instantiating it with FakeGuardImpl. 55 */ 56 template <typename T> 57 class InlineDeviceGuard { 58 public: 59 // Note [Omitted default constructor from RAII] 60 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 61 // In principle, we could add a default constructor to 62 // DeviceGuard which reads the current device and promises to 63 // restore to that device on exit. However, most cases where you 64 // would have written this, you probably meant to actually just 65 // use OptionalDeviceGuard (since you don't actually need the 66 // restore to happen if you don't ever actually set the device). 67 // We remove the constructor here to encourage you to think about 68 // what you actually want to happen. 69 explicit InlineDeviceGuard() = delete; 70 71 /// Set the current device to the passed Device. InlineDeviceGuard(Device device)72 explicit InlineDeviceGuard(Device device) 73 : impl_(device.type()), 74 original_device_( 75 device.index() == -1 ? impl_.getDevice() 76 : impl_.exchangeDevice(device)), 77 current_device_(device.index() == -1 ? original_device_ : device) {} 78 79 /// Set the current device index to the passed DeviceIndex. (The 80 /// device type is inferred from the template parameter T). 81 template < 82 typename U = T, 83 typename = 84 typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>> InlineDeviceGuard(DeviceIndex device_index)85 explicit InlineDeviceGuard(DeviceIndex device_index) 86 : InlineDeviceGuard(Device(U::static_type, device_index)) {} 87 88 /// Construct an InlineDeviceGuard using VirtualGuardImpl with an explicit 89 /// DeviceGuardImplInterface pointer. 90 template < 91 typename U = T, 92 typename = typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>>> InlineDeviceGuard(Device device,const DeviceGuardImplInterface * impl)93 explicit InlineDeviceGuard( 94 Device device, 95 const DeviceGuardImplInterface* impl) 96 : impl_( 97 VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type()))), 98 original_device_( 99 device.index() == -1 ? impl_.getDevice() 100 : impl_.exchangeDevice(device)), 101 current_device_(device.index() == -1 ? original_device_ : device) {} 102 103 /// Copy is disallowed 104 InlineDeviceGuard(const InlineDeviceGuard<T>&) = delete; 105 InlineDeviceGuard<T>& operator=(const InlineDeviceGuard<T>&) = delete; 106 107 /// Move is disallowed, as DeviceGuard does not have an uninitialized state, 108 /// which is required for moves on types with nontrivial destructors. 109 InlineDeviceGuard(InlineDeviceGuard<T>&& other) = delete; 110 InlineDeviceGuard& operator=(InlineDeviceGuard<T>&& other) = delete; 111 ~InlineDeviceGuard()112 ~InlineDeviceGuard() { 113 impl_.uncheckedSetDevice(original_device_); 114 } 115 116 /// Sets the device to the given one. 117 template < 118 typename U = T, 119 typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>, int> = 0> set_device(at::Device device)120 void set_device(at::Device device) { 121 AT_ASSERT( 122 (U::static_type == DeviceType::HIP && device.is_cuda()) || 123 device.type() == U::static_type); 124 auto index = device.index(); 125 if (index == -1) 126 return; 127 impl_.setDevice(device); 128 current_device_ = device; 129 } 130 131 /// Resets the currently set device to its original device, and then sets the 132 /// current device to the passed device. This is effectively equivalent to 133 /// set_device when a guard supports only a single device type. 134 template <typename U = T> reset_device(at::Device device)135 typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>> reset_device( 136 at::Device device) { 137 set_device(device); 138 } 139 140 /// Resets the currently set device to its original device, and then sets the 141 /// current device to the passed device (for a possibly different device 142 /// type). 143 /// 144 /// This method is named reset_device to highlight the fact that previous 145 /// device settings from this guard are NOT preserved, even if the device 146 /// has a different device type. For example: 147 /// 148 /// // CUDA device is 0 149 /// DeviceGuard g(Device(kCUDA, 1)); 150 /// g.reset_device(Device(kHIP, 2)); 151 /// // CUDA device is 0 (!!) 152 /// 153 /// NOTE: this implementation may skip some device setting if it can prove 154 /// that it is unnecessary. 155 /// 156 /// Optional argument is for testing only. 157 template <typename U = T> 158 typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>> reset_device( 159 at::Device device, 160 const impl::DeviceGuardImplInterface* impl = nullptr) { 161 auto index = device.index(); 162 if (index == -1) 163 return; 164 if (device.type() == original_device_.type()) { 165 AT_ASSERT(impl == nullptr || impl->type() == device.type()); 166 impl_.setDevice(device); 167 current_device_ = device; 168 } else { 169 // Destruct and reconstruct the DeviceGuard in place 170 impl_.setDevice(original_device_); 171 impl_ = !impl ? VirtualGuardImpl(device.type()) : VirtualGuardImpl(impl); 172 original_device_ = impl_.exchangeDevice(device); 173 current_device_ = device; 174 } 175 } 176 177 /// Sets the device index to the given one. The device type is inferred 178 /// from the original device type. set_index(DeviceIndex index)179 void set_index(DeviceIndex index) { 180 reset_device(Device(original_device_.type(), index)); 181 } 182 183 /// Returns the device that was set at the time the most recent 184 /// reset_device(), or otherwise the device at construction time. original_device()185 Device original_device() const { 186 return original_device_; 187 } 188 189 /// Returns the most recent device that was set using this device guard, 190 /// either from construction, or via set_device/reset_device/set_index. current_device()191 Device current_device() const { 192 return current_device_; 193 } 194 195 protected: 196 T impl_; 197 198 private: 199 Device original_device_; 200 Device current_device_; 201 }; 202 203 /** 204 * A OptionalDeviceGuard is an RAII class that sets a device to some value on 205 * initialization, and resets the device to its original value on destruction. 206 * 207 * InlineOptionalDeviceGuard is a helper class for implementing 208 * OptionalDeviceGuards. See guidance in InlineDeviceGuard on how to 209 * use this. See OptionalDeviceGuard for user-oriented usage notes. 210 */ 211 template <typename T> 212 class InlineOptionalDeviceGuard { 213 public: 214 // Note [Explicit initialization of optional fields] 215 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 216 // Explicit initialization of optional fields 217 // required to workaround an nvcc bug; see 218 // https://github.com/pytorch/pytorch/issues/12117 219 220 /// Creates an uninitialized OptionalDeviceGuard. InlineOptionalDeviceGuard()221 explicit InlineOptionalDeviceGuard() 222 : guard_() // See Note [Explicit initialization of optional fields] 223 {} 224 225 /// Set the current device to the passed Device, if it is not nullopt. InlineOptionalDeviceGuard(std::optional<Device> device_opt)226 explicit InlineOptionalDeviceGuard(std::optional<Device> device_opt) 227 : guard_() { // See Note [Explicit initialization of optional fields] 228 if (device_opt.has_value()) { 229 guard_.emplace(device_opt.value()); 230 } 231 } 232 233 /// Set the current device to the passed DeviceIndex, if it is not nullopt. 234 template < 235 typename U = T, 236 typename = 237 typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>> InlineOptionalDeviceGuard(std::optional<DeviceIndex> device_index_opt)238 explicit InlineOptionalDeviceGuard( 239 std::optional<DeviceIndex> device_index_opt) 240 : guard_() { // See Note [Explicit initialization of optional fields] 241 if (device_index_opt.has_value()) { 242 guard_.emplace(device_index_opt.value()); 243 } 244 } 245 246 /// All constructors of DeviceGuard are valid for OptionalDeviceGuard 247 /// and result in initialized OptionalDeviceGuard. 248 template <typename... Args> InlineOptionalDeviceGuard(Args &&...args)249 explicit InlineOptionalDeviceGuard(Args&&... args) 250 : guard_(std::in_place, std::forward<Args>(args)...) {} 251 252 // TODO: Consider reading Tensor and TensorList constructors here, when 253 // Tensor moves to c10. (These are only valid on OptionalDeviceGuard, 254 // because a Tensor may be undefined, in which case we need an uninitialized 255 // tensor guard.) 256 257 // Note [Move construction for RAII guards is tricky] 258 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 259 // In principle, move construction is useful for terminating 260 // the lifetime of a `OptionalDeviceGuard` early; for example: 261 // 262 // // current device is d0 263 // OptionalDeviceGuard g1(d1); 264 // // current device is d1 265 // { 266 // OptionalDeviceGuard g2(std::move(g1)); 267 // } 268 // // current device is d0!! 269 // 270 // However, it's difficult to implement the move constructor 271 // in a way that works in all situations. For example, consider 272 // the following example: 273 // 274 // OptionalDeviceGuard g1(d1); 275 // { 276 // OptionalDeviceGuard g2(d2); 277 // { 278 // OptionalDeviceGuard g3(std::move(g1)); // !!! 279 // } 280 // } 281 // 282 // What should the current device be while g3 in scope... and what 283 // should it be after it goes out of scope? What about g2? 284 // There don't seem to be satisfactory answers for these questions. 285 // 286 // It's in principle possible to raise an error when this occurs 287 // by doing some extra thread-local bookkeeping. But why bother? 288 // Just don't provide the constructor. 289 InlineOptionalDeviceGuard(InlineOptionalDeviceGuard<T>&& other) = delete; 290 291 // Note [Move assignment for RAII guards is tricky] 292 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 293 // Move assignment is deleted, because you need to know which guard was 294 // defined "first", as that guard's original_device_ wins--with the current 295 // representation, we have no way of telling which is the case. (Move 296 // construction does not have this problem, as one guard is always 297 // uninitialized.) 298 // 299 // We can make this clear by way of a pair of examples: 300 // 301 // Example 1: 302 // 303 // // initial device is n0 304 // { 305 // CUDAGuard g1(n1); 306 // { 307 // CUDAGuard g2(n2); 308 // // current device should be n2 309 // g1 = std::move(g2); 310 // // current device should still be n2 311 // } 312 // // current device should still be n2 313 // } 314 // // current device should be n0 315 // 316 // Example 2 (flip the order of the two guards): 317 // 318 // // initial device is n0 319 // { 320 // CUDAGuard g2(n2); 321 // { 322 // CUDAGuard g1(n1); 323 // // current device should be n1 324 // g1 = std::move(g2); 325 // // current device should be n2 326 // } 327 // // current device should be n0 (since g2 has been vacated) 328 // } 329 // 330 // In both examples, we need g1 to restore to n0 after move assignment. 331 // However, in example 1, this is determined by the restore value of g1 332 // (prior to the move). In example 2, however, it is determined by the the 333 // restore value of g2(!!). We don't know which one should win, without having 334 // a way of telling which guard was allocated first. 335 // 336 // We could solve this with an extra thread-local variable. But no one is 337 // actually using move-assignment. So just get rid of it. 338 InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) = 339 delete; 340 341 /// Sets the device to the given one. Initializes OptionalDeviceGuard if it 342 /// is not already initialized. 343 template < 344 typename U = T, 345 typename = 346 typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>> set_device(at::Device device)347 void set_device(at::Device device) { 348 if (!guard_.has_value()) { 349 guard_.emplace(device); 350 } else { 351 guard_->set_device(device); 352 } 353 } 354 355 /// Resets the currently set device to its original device, and then sets the 356 /// current device to the passed device (for a possibly different device 357 /// type). Initializes OptionalDeviceGuard if it is not already initialized. 358 /// 359 /// See notes on why this is called reset_device on InlineDeviceGuard. 360 /// 361 /// Optional argument is for testing only. 362 template < 363 typename U = T, 364 typename = typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>>> 365 void reset_device( 366 at::Device device, 367 const DeviceGuardImplInterface* impl = nullptr) { 368 if (!guard_.has_value()) { 369 guard_.emplace(device, impl); 370 } else { 371 guard_->reset_device(device, impl); 372 } 373 } 374 375 /// Resets the currently set device to its original device, and then sets the 376 /// current device to the passed device. Initializes the guard if it is 377 /// not already initialized. This is effectively equivalent to set_device 378 /// when a guard supports only a single device type. 379 template < 380 typename U = T, 381 typename = 382 typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>> reset_device(at::Device device)383 void reset_device(at::Device device) { 384 if (!guard_.has_value()) { 385 guard_.emplace(device); 386 } else { 387 guard_->reset_device(device); 388 } 389 } 390 391 /// Sets the device index to the given one. The device type is statically 392 /// known. 393 template < 394 typename U = T, 395 typename = 396 typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>> set_index(DeviceIndex index)397 void set_index(DeviceIndex index) { 398 if (!guard_.has_value()) { 399 guard_.emplace(index); 400 } else { 401 guard_->set_index(index); 402 } 403 } 404 405 /// Returns the device that was set immediately prior to initialization of 406 /// the, guard, or nullopt if the guard is uninitialized. original_device()407 std::optional<Device> original_device() const { 408 return guard_.has_value() ? std::make_optional(guard_->original_device()) 409 : std::nullopt; 410 } 411 412 /// Returns the most recent device that was set using this device guard, 413 /// either from construction, or via set_device, if the guard is initialized, 414 /// or nullopt if the guard is uninitialized. current_device()415 std::optional<Device> current_device() const { 416 return guard_.has_value() ? std::make_optional(guard_->current_device()) 417 : std::nullopt; 418 } 419 420 /// Restore the original device, resetting this guard to uninitialized state. reset()421 void reset() { 422 guard_.reset(); 423 } 424 425 private: 426 std::optional<InlineDeviceGuard<T>> guard_; 427 }; 428 429 } // namespace c10::impl 430