1 #pragma once 2 3 #include <c10/core/Device.h> 4 #include <c10/core/impl/DeviceGuardImplInterface.h> 5 #include <c10/core/impl/InlineDeviceGuard.h> 6 #include <c10/core/impl/VirtualGuardImpl.h> 7 #include <c10/util/Optional.h> 8 9 namespace c10 { 10 11 /// RAII guard that sets a certain default device in its constructor, and 12 /// changes it back to the device that was originally active upon destruction. 13 /// 14 /// The device is always reset to the one that was active at the time of 15 /// construction of the guard. Even if you `set_device` after construction, the 16 /// destructor will still reset the device to the one that was active at 17 /// construction time. 18 /// 19 /// This device guard does NOT have an uninitialized state; it is guaranteed 20 /// to reset a device on exit. If you are in a situation where you *might* 21 /// want to setup a guard (i.e., are looking for the moral equivalent 22 /// of std::optional<DeviceGuard>), see OptionalDeviceGuard. 23 class DeviceGuard { 24 public: 25 /// No default constructor; see Note [Omitted default constructor from RAII] 26 explicit DeviceGuard() = delete; 27 28 /// Set the current device to the passed Device. DeviceGuard(Device device)29 explicit DeviceGuard(Device device) : guard_(device) {} 30 31 /// This constructor is for testing only. DeviceGuard(Device device,const impl::DeviceGuardImplInterface * impl)32 explicit DeviceGuard( 33 Device device, 34 const impl::DeviceGuardImplInterface* impl) 35 : guard_(device, impl) {} 36 37 /// Copy is disallowed 38 DeviceGuard(const DeviceGuard&) = delete; 39 DeviceGuard& operator=(const DeviceGuard&) = delete; 40 41 /// Move is disallowed, as DeviceGuard does not have an uninitialized state, 42 /// which is required for moves on types with nontrivial destructors. 43 DeviceGuard(DeviceGuard&& other) = delete; 44 DeviceGuard& operator=(DeviceGuard&& other) = delete; 45 46 /// Sets the device to the given one. The specified device must be consistent 47 /// with the device type originally specified during guard construction. 48 /// 49 /// TODO: The consistency check here is inconsistent with StreamGuard's 50 /// behavior with set_stream, where a stream on a different device than 51 /// the original one isn't an error; we just reset the stream and then 52 /// switch devices. reset_device(at::Device device)53 void reset_device(at::Device device) { 54 guard_.reset_device(device); 55 } 56 57 /// This method is for testing only. reset_device(at::Device device,const impl::DeviceGuardImplInterface * impl)58 void reset_device( 59 at::Device device, 60 const impl::DeviceGuardImplInterface* impl) { 61 guard_.reset_device(device, impl); 62 } 63 64 /// Sets the device index to the given one. The device type is inferred 65 /// from the original device type the guard was constructed with. set_index(DeviceIndex index)66 void set_index(DeviceIndex index) { 67 guard_.set_index(index); 68 } 69 70 /// Returns the device that was set at the time the guard was constructed. original_device()71 Device original_device() const { 72 return guard_.original_device(); 73 } 74 75 /// Returns the most recent device that was set using this device guard, 76 /// either from construction, or via set_device. current_device()77 Device current_device() const { 78 return guard_.current_device(); 79 } 80 81 private: 82 impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_; 83 }; 84 85 /** 86 * A OptionalDeviceGuard is an RAII class that sets a device to some value on 87 * initialization, and resets the device to its original value on destruction. 88 * Morally, a OptionalDeviceGuard is equivalent to std::optional<DeviceGuard>, 89 * but with extra constructors and methods as appropriate. 90 * 91 * Besides its obvious use (optionally applying a DeviceGuard), 92 * OptionalDeviceGuard is often also used for the following idiom: 93 * 94 * OptionalDeviceGuard g; 95 * for (const auto& t : tensors) { 96 * g.set_device(t.device()); 97 * do_something_with(t); 98 * } 99 * 100 * This usage is marginally more efficient than constructing a DeviceGuard every 101 * iteration of the for loop, as it avoids an unnecessary device reset. 102 * 103 * Unlike DeviceGuard, a OptionalDeviceGuard may be uninitialized. This occurs 104 * when you use the nullary constructor, or pass a nullopt to the constructor. 105 * Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the 106 * original device was and they do not reset on destruction. This is why 107 * original_device() and current_device() return std::optional<Device> rather 108 * than Device (as they do in DeviceGuard), and also is why we didn't just 109 * provide OptionalDeviceGuard by default and hide DeviceGuard from users. 110 * 111 * The semantics of an OptionalDeviceGuard are exactly explained by thinking 112 * of it as an std::optional<DeviceGuard>. In particular, an initialized 113 * OptionalDeviceGuard doesn't restore device to its value at construction; it 114 * restores device to its value *at initialization*. So if you have the 115 * program: 116 * 117 * setDevice(1); 118 * OptionalDeviceGuard g; 119 * setDevice(2); 120 * g.reset_device(Device(DeviceType::CUDA, 3)); // initializes! 121 * 122 * On destruction, g will reset device to 2, rather than 1. 123 * 124 * An uninitialized OptionalDeviceGuard is distinct from a (initialized) 125 * DeviceGuard whose original_device_ and current_device_ match, since the 126 * DeviceGuard will still reset the device to original_device_. 127 */ 128 class OptionalDeviceGuard { 129 public: 130 /// Create an uninitialized guard. Set the guard later using reset_device. 131 explicit OptionalDeviceGuard() = default; 132 133 /// Initialize the guard, setting the current device to the passed Device. OptionalDeviceGuard(Device device)134 explicit OptionalDeviceGuard(Device device) : guard_(device) {} 135 136 /// Initialize the guard if a Device is passed; otherwise leave the 137 /// guard uninitialized. OptionalDeviceGuard(std::optional<Device> device)138 explicit OptionalDeviceGuard(std::optional<Device> device) : guard_(device) {} 139 140 /// Constructor for testing only. OptionalDeviceGuard(Device device,const impl::DeviceGuardImplInterface * impl)141 explicit OptionalDeviceGuard( 142 Device device, 143 const impl::DeviceGuardImplInterface* impl) 144 : guard_(device, impl) {} 145 146 /// Copy is disallowed 147 OptionalDeviceGuard(const OptionalDeviceGuard&) = delete; 148 OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete; 149 150 /// Move is disallowed 151 /// See Note [Explicit initialization of optional fields] 152 /// and // Note [Move construction for RAII guards is tricky] 153 /// for rationale. 154 OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete; 155 OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete; 156 157 /// Sets the device to the given one. The specified device must be consistent 158 /// with the device type originally specified during guard construction. reset_device(at::Device device)159 void reset_device(at::Device device) { 160 guard_.reset_device(device); 161 } 162 163 /// For testing only reset_device(at::Device device,const impl::DeviceGuardImplInterface * impl)164 void reset_device( 165 at::Device device, 166 const impl::DeviceGuardImplInterface* impl) { 167 guard_.reset_device(device, impl); 168 } 169 170 /// Returns the device that was set at the time the guard was constructed. original_device()171 std::optional<Device> original_device() const { 172 return guard_.original_device(); 173 } 174 175 /// Returns the most recent device that was set using this device guard, 176 /// either from construction, or via reset_device. current_device()177 std::optional<Device> current_device() const { 178 return guard_.current_device(); 179 } 180 181 private: 182 impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_{}; 183 }; 184 185 // Note [Whither the DeviceGuard boilerplate] 186 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 187 // Design note: in principle, we could avoid these wrappers using: 188 // 189 // using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>; 190 // using OptionalDeviceGuard = 191 // impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>; 192 // 193 // But the error messages are worse, and our users can't just look at the 194 // header file to find out what's going on. Furthermore, for specializations 195 // like CUDAStreamGuard, it can be profitable to replace some interfaces with 196 // refined types (e.g., return CUDAStream instead of Stream). So, we eat 197 // the boilerplate and write out the API explicitly. 198 199 } // namespace c10 200