xref: /aosp_15_r20/external/pytorch/c10/core/DeviceGuard.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Device.h>
4 #include <c10/core/impl/DeviceGuardImplInterface.h>
5 #include <c10/core/impl/InlineDeviceGuard.h>
6 #include <c10/core/impl/VirtualGuardImpl.h>
7 #include <c10/util/Optional.h>
8 
9 namespace c10 {
10 
11 /// RAII guard that sets a certain default device in its constructor, and
12 /// changes it back to the device that was originally active upon destruction.
13 ///
/// The device is always reset to the one that was active at the time the
/// guard was constructed. Even if you change the device after construction
/// (via `reset_device` or `set_index`), the destructor still restores the
/// device that was active at construction time.
18 ///
/// This device guard does NOT have an uninitialized state; it is guaranteed
/// to reset a device on exit.  If you are in a situation where you *might*
/// want to set up a guard (i.e., are looking for the moral equivalent
/// of std::optional<DeviceGuard>), see OptionalDeviceGuard.
23 class DeviceGuard {
24  public:
25   /// No default constructor; see Note [Omitted default constructor from RAII]
26   explicit DeviceGuard() = delete;
27 
28   /// Set the current device to the passed Device.
DeviceGuard(Device device)29   explicit DeviceGuard(Device device) : guard_(device) {}
30 
31   /// This constructor is for testing only.
DeviceGuard(Device device,const impl::DeviceGuardImplInterface * impl)32   explicit DeviceGuard(
33       Device device,
34       const impl::DeviceGuardImplInterface* impl)
35       : guard_(device, impl) {}
36 
37   /// Copy is disallowed
38   DeviceGuard(const DeviceGuard&) = delete;
39   DeviceGuard& operator=(const DeviceGuard&) = delete;
40 
41   /// Move is disallowed, as DeviceGuard does not have an uninitialized state,
42   /// which is required for moves on types with nontrivial destructors.
43   DeviceGuard(DeviceGuard&& other) = delete;
44   DeviceGuard& operator=(DeviceGuard&& other) = delete;
45 
46   /// Sets the device to the given one.  The specified device must be consistent
47   /// with the device type originally specified during guard construction.
48   ///
49   /// TODO: The consistency check here is inconsistent with StreamGuard's
50   /// behavior with set_stream, where a stream on a different device than
51   /// the original one isn't an error; we just reset the stream and then
52   /// switch devices.
reset_device(at::Device device)53   void reset_device(at::Device device) {
54     guard_.reset_device(device);
55   }
56 
57   /// This method is for testing only.
reset_device(at::Device device,const impl::DeviceGuardImplInterface * impl)58   void reset_device(
59       at::Device device,
60       const impl::DeviceGuardImplInterface* impl) {
61     guard_.reset_device(device, impl);
62   }
63 
64   /// Sets the device index to the given one.  The device type is inferred
65   /// from the original device type the guard was constructed with.
set_index(DeviceIndex index)66   void set_index(DeviceIndex index) {
67     guard_.set_index(index);
68   }
69 
70   /// Returns the device that was set at the time the guard was constructed.
original_device()71   Device original_device() const {
72     return guard_.original_device();
73   }
74 
75   /// Returns the most recent device that was set using this device guard,
76   /// either from construction, or via set_device.
current_device()77   Device current_device() const {
78     return guard_.current_device();
79   }
80 
81  private:
82   impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_;
83 };
84 
85 /**
 * An OptionalDeviceGuard is an RAII class that sets a device to some value on
 * initialization, and resets the device to its original value on destruction.
 * Morally, an OptionalDeviceGuard is equivalent to std::optional<DeviceGuard>,
 * but with extra constructors and methods as appropriate.
90  *
91  * Besides its obvious use (optionally applying a DeviceGuard),
92  * OptionalDeviceGuard is often also used for the following idiom:
93  *
94  *    OptionalDeviceGuard g;
95  *    for (const auto& t : tensors) {
 *      g.reset_device(t.device());
97  *      do_something_with(t);
98  *    }
99  *
100  * This usage is marginally more efficient than constructing a DeviceGuard every
101  * iteration of the for loop, as it avoids an unnecessary device reset.
102  *
 * Unlike DeviceGuard, an OptionalDeviceGuard may be uninitialized.  This occurs
 * when you use the nullary constructor, or pass a nullopt to the constructor.
 * Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the
 * original device was and they do not reset on destruction.  This is why
 * original_device() and current_device() return std::optional<Device> rather
 * than Device (as they do in DeviceGuard), and also is why we didn't just
 * provide OptionalDeviceGuard by default and hide DeviceGuard from users.
110  *
111  * The semantics of an OptionalDeviceGuard are exactly explained by thinking
112  * of it as an std::optional<DeviceGuard>.  In particular, an initialized
113  * OptionalDeviceGuard doesn't restore device to its value at construction; it
114  * restores device to its value *at initialization*.  So if you have the
115  * program:
116  *
117  *     setDevice(1);
118  *     OptionalDeviceGuard g;
119  *     setDevice(2);
120  *     g.reset_device(Device(DeviceType::CUDA, 3));  // initializes!
121  *
122  * On destruction, g will reset device to 2, rather than 1.
123  *
 * An uninitialized OptionalDeviceGuard is distinct from an (initialized)
 * DeviceGuard whose original_device_ and current_device_ match, since the
 * DeviceGuard will still reset the device to original_device_.
127  */
128 class OptionalDeviceGuard {
129  public:
130   /// Create an uninitialized guard.  Set the guard later using reset_device.
131   explicit OptionalDeviceGuard() = default;
132 
133   /// Initialize the guard, setting the current device to the passed Device.
OptionalDeviceGuard(Device device)134   explicit OptionalDeviceGuard(Device device) : guard_(device) {}
135 
136   /// Initialize the guard if a Device is passed; otherwise leave the
137   /// guard uninitialized.
OptionalDeviceGuard(std::optional<Device> device)138   explicit OptionalDeviceGuard(std::optional<Device> device) : guard_(device) {}
139 
140   /// Constructor for testing only.
OptionalDeviceGuard(Device device,const impl::DeviceGuardImplInterface * impl)141   explicit OptionalDeviceGuard(
142       Device device,
143       const impl::DeviceGuardImplInterface* impl)
144       : guard_(device, impl) {}
145 
146   /// Copy is disallowed
147   OptionalDeviceGuard(const OptionalDeviceGuard&) = delete;
148   OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete;
149 
150   /// Move is disallowed
151   /// See Note [Explicit initialization of optional fields]
152   /// and // Note [Move construction for RAII guards is tricky]
153   /// for rationale.
154   OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete;
155   OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete;
156 
157   /// Sets the device to the given one.  The specified device must be consistent
158   /// with the device type originally specified during guard construction.
reset_device(at::Device device)159   void reset_device(at::Device device) {
160     guard_.reset_device(device);
161   }
162 
163   /// For testing only
reset_device(at::Device device,const impl::DeviceGuardImplInterface * impl)164   void reset_device(
165       at::Device device,
166       const impl::DeviceGuardImplInterface* impl) {
167     guard_.reset_device(device, impl);
168   }
169 
170   /// Returns the device that was set at the time the guard was constructed.
original_device()171   std::optional<Device> original_device() const {
172     return guard_.original_device();
173   }
174 
175   /// Returns the most recent device that was set using this device guard,
176   /// either from construction, or via reset_device.
current_device()177   std::optional<Device> current_device() const {
178     return guard_.current_device();
179   }
180 
181  private:
182   impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_{};
183 };
184 
185 // Note [Whither the DeviceGuard boilerplate]
186 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
187 // Design note: in principle, we could avoid these wrappers using:
188 //
189 // using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
190 // using OptionalDeviceGuard =
191 // impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
192 //
193 // But the error messages are worse, and our users can't just look at the
194 // header file to find out what's going on.  Furthermore, for specializations
195 // like CUDAStreamGuard, it can be profitable to replace some interfaces with
196 // refined types (e.g., return CUDAStream instead of Stream).  So, we eat
197 // the boilerplate and write out the API explicitly.
198 
199 } // namespace c10
200