xref: /aosp_15_r20/external/pytorch/c10/core/impl/InlineDeviceGuard.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // This file provides implementations of InlineDeviceGuard and
4 // InlineOptionalDeviceGuard.
5 
6 #include <c10/core/Device.h>
7 #include <c10/core/DeviceType.h>
8 #include <c10/core/impl/DeviceGuardImplInterface.h>
9 #include <c10/core/impl/VirtualGuardImpl.h>
10 #include <c10/util/Exception.h>
11 #include <c10/util/Optional.h>
12 #include <type_traits>
13 #include <utility>
14 
15 namespace c10::impl {
16 
17 /**
18  * A DeviceGuard is an RAII class that sets a device to some value
19  * on construction, and resets the device to its original value on
20  * destruction.
21  *
22  * InlineDeviceGuard is a helper class for implementing DeviceGuards.
23  * It is templated over a DeviceGuardImpl (anything that implements
24  * DeviceGuardImplInterface).  There are two primary ways to instantiate
25  * InlineDeviceGuard:
26  *
27  *  - With a concrete implementation of DeviceGuardImpl, e.g., CUDAGuardImpl.
28  *    This is the best way to use InlineDeviceGuard, as all calls are
29  *    devirtualized, giving you code as efficient as straight line
30  *    calls to cudaGetDevice/cudaSetDevice.
31  *
32  *  - With VirtualGuardImpl, which does a virtual dispatch to a DeviceGuardImpl
33  *    retrieved from a DeviceType registry.  We have explicitly instantiated
34  *    InlineDeviceGuard this way as c10::DeviceGuard.
35  *
36  * If you are in a hurry, you can use InlineDeviceGuard directly:
37  *
38  *    using CUDAGuard = impl::InlineDeviceGuard<CUDAGuardImpl>;
39  *
40  * However, you can provide a better user experience if you explicitly write a
41  * wrapper class that itself contains the template instantiation:
42  *
43  *    class CUDAGuard {
44  *    public:
45  *      // ... the API ...
46  *    private:
47  *      impl::InlineDeviceGuard<CUDAGuardImpl> guard_;
48  *    }
49  *
50  * The wrapper class provides a good place to write documentation, and helps
51  * avoid weird template instantiation errors when a user incorrectly uses the
52  * class.
53  *
54  * If you need to test this class, consider instantiating it with FakeGuardImpl.
55  */
56 template <typename T>
57 class InlineDeviceGuard {
58  public:
59   // Note [Omitted default constructor from RAII]
60   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
61   // In principle, we could add a default constructor to
62   // DeviceGuard which reads the current device and promises to
63   // restore to that device on exit.  However, most cases where you
64   // would have written this, you probably meant to actually just
65   // use OptionalDeviceGuard (since you don't actually need the
66   // restore to happen if you don't ever actually set the device).
67   // We remove the constructor here to encourage you to think about
68   // what you actually want to happen.
69   explicit InlineDeviceGuard() = delete;
70 
71   /// Set the current device to the passed Device.
InlineDeviceGuard(Device device)72   explicit InlineDeviceGuard(Device device)
73       : impl_(device.type()),
74         original_device_(
75             device.index() == -1 ? impl_.getDevice()
76                                  : impl_.exchangeDevice(device)),
77         current_device_(device.index() == -1 ? original_device_ : device) {}
78 
79   /// Set the current device index to the passed DeviceIndex.  (The
80   /// device type is inferred from the template parameter T).
81   template <
82       typename U = T,
83       typename =
84           typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
InlineDeviceGuard(DeviceIndex device_index)85   explicit InlineDeviceGuard(DeviceIndex device_index)
86       : InlineDeviceGuard(Device(U::static_type, device_index)) {}
87 
88   /// Construct an InlineDeviceGuard using VirtualGuardImpl with an explicit
89   /// DeviceGuardImplInterface pointer.
90   template <
91       typename U = T,
92       typename = typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>>>
InlineDeviceGuard(Device device,const DeviceGuardImplInterface * impl)93   explicit InlineDeviceGuard(
94       Device device,
95       const DeviceGuardImplInterface* impl)
96       : impl_(
97             VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type()))),
98         original_device_(
99             device.index() == -1 ? impl_.getDevice()
100                                  : impl_.exchangeDevice(device)),
101         current_device_(device.index() == -1 ? original_device_ : device) {}
102 
103   /// Copy is disallowed
104   InlineDeviceGuard(const InlineDeviceGuard<T>&) = delete;
105   InlineDeviceGuard<T>& operator=(const InlineDeviceGuard<T>&) = delete;
106 
107   /// Move is disallowed, as DeviceGuard does not have an uninitialized state,
108   /// which is required for moves on types with nontrivial destructors.
109   InlineDeviceGuard(InlineDeviceGuard<T>&& other) = delete;
110   InlineDeviceGuard& operator=(InlineDeviceGuard<T>&& other) = delete;
111 
~InlineDeviceGuard()112   ~InlineDeviceGuard() {
113     impl_.uncheckedSetDevice(original_device_);
114   }
115 
116   /// Sets the device to the given one.
117   template <
118       typename U = T,
119       typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>, int> = 0>
set_device(at::Device device)120   void set_device(at::Device device) {
121     AT_ASSERT(
122         (U::static_type == DeviceType::HIP && device.is_cuda()) ||
123         device.type() == U::static_type);
124     auto index = device.index();
125     if (index == -1)
126       return;
127     impl_.setDevice(device);
128     current_device_ = device;
129   }
130 
131   /// Resets the currently set device to its original device, and then sets the
132   /// current device to the passed device.  This is effectively equivalent to
133   /// set_device when a guard supports only a single device type.
134   template <typename U = T>
reset_device(at::Device device)135   typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>> reset_device(
136       at::Device device) {
137     set_device(device);
138   }
139 
140   /// Resets the currently set device to its original device, and then sets the
141   /// current device to the passed device (for a possibly different device
142   /// type).
143   ///
144   /// This method is named reset_device to highlight the fact that previous
145   /// device settings from this guard are NOT preserved, even if the device
146   /// has a different device type.  For example:
147   ///
148   ///   // CUDA device is 0
149   ///   DeviceGuard g(Device(kCUDA, 1));
150   ///   g.reset_device(Device(kHIP, 2));
151   ///   // CUDA device is 0 (!!)
152   ///
153   /// NOTE: this implementation may skip some device setting if it can prove
154   /// that it is unnecessary.
155   ///
156   /// Optional argument is for testing only.
157   template <typename U = T>
158   typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>> reset_device(
159       at::Device device,
160       const impl::DeviceGuardImplInterface* impl = nullptr) {
161     auto index = device.index();
162     if (index == -1)
163       return;
164     if (device.type() == original_device_.type()) {
165       AT_ASSERT(impl == nullptr || impl->type() == device.type());
166       impl_.setDevice(device);
167       current_device_ = device;
168     } else {
169       // Destruct and reconstruct the DeviceGuard in place
170       impl_.setDevice(original_device_);
171       impl_ = !impl ? VirtualGuardImpl(device.type()) : VirtualGuardImpl(impl);
172       original_device_ = impl_.exchangeDevice(device);
173       current_device_ = device;
174     }
175   }
176 
177   /// Sets the device index to the given one.  The device type is inferred
178   /// from the original device type.
set_index(DeviceIndex index)179   void set_index(DeviceIndex index) {
180     reset_device(Device(original_device_.type(), index));
181   }
182 
183   /// Returns the device that was set at the time the most recent
184   /// reset_device(), or otherwise the device at construction time.
original_device()185   Device original_device() const {
186     return original_device_;
187   }
188 
189   /// Returns the most recent device that was set using this device guard,
190   /// either from construction, or via set_device/reset_device/set_index.
current_device()191   Device current_device() const {
192     return current_device_;
193   }
194 
195  protected:
196   T impl_;
197 
198  private:
199   Device original_device_;
200   Device current_device_;
201 };
202 
203 /**
204  * A OptionalDeviceGuard is an RAII class that sets a device to some value on
205  * initialization, and resets the device to its original value on destruction.
206  *
207  * InlineOptionalDeviceGuard is a helper class for implementing
208  * OptionalDeviceGuards.  See guidance in InlineDeviceGuard on how to
209  * use this.  See OptionalDeviceGuard for user-oriented usage notes.
210  */
211 template <typename T>
212 class InlineOptionalDeviceGuard {
213  public:
214   // Note [Explicit initialization of optional fields]
215   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
216   // Explicit initialization of optional fields
217   // required to workaround an nvcc bug; see
218   // https://github.com/pytorch/pytorch/issues/12117
219 
220   /// Creates an uninitialized OptionalDeviceGuard.
InlineOptionalDeviceGuard()221   explicit InlineOptionalDeviceGuard()
222       : guard_() // See Note [Explicit initialization of optional fields]
223   {}
224 
225   /// Set the current device to the passed Device, if it is not nullopt.
InlineOptionalDeviceGuard(std::optional<Device> device_opt)226   explicit InlineOptionalDeviceGuard(std::optional<Device> device_opt)
227       : guard_() { // See Note [Explicit initialization of optional fields]
228     if (device_opt.has_value()) {
229       guard_.emplace(device_opt.value());
230     }
231   }
232 
233   /// Set the current device to the passed DeviceIndex, if it is not nullopt.
234   template <
235       typename U = T,
236       typename =
237           typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
InlineOptionalDeviceGuard(std::optional<DeviceIndex> device_index_opt)238   explicit InlineOptionalDeviceGuard(
239       std::optional<DeviceIndex> device_index_opt)
240       : guard_() { // See Note [Explicit initialization of optional fields]
241     if (device_index_opt.has_value()) {
242       guard_.emplace(device_index_opt.value());
243     }
244   }
245 
246   /// All constructors of DeviceGuard are valid for OptionalDeviceGuard
247   /// and result in initialized OptionalDeviceGuard.
248   template <typename... Args>
InlineOptionalDeviceGuard(Args &&...args)249   explicit InlineOptionalDeviceGuard(Args&&... args)
250       : guard_(std::in_place, std::forward<Args>(args)...) {}
251 
252   // TODO: Consider reading Tensor and TensorList constructors here, when
253   // Tensor moves to c10.  (These are only valid on OptionalDeviceGuard,
254   // because a Tensor may be undefined, in which case we need an uninitialized
255   // tensor guard.)
256 
257   // Note [Move construction for RAII guards is tricky]
258   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
259   // In principle, move construction is useful for terminating
260   // the lifetime of a `OptionalDeviceGuard` early; for example:
261   //
262   //     // current device is d0
263   //     OptionalDeviceGuard g1(d1);
264   //     // current device is d1
265   //     {
266   //       OptionalDeviceGuard g2(std::move(g1));
267   //     }
268   //     // current device is d0!!
269   //
270   // However, it's difficult to implement the move constructor
271   // in a way that works in all situations.  For example, consider
272   // the following example:
273   //
274   //     OptionalDeviceGuard g1(d1);
275   //     {
276   //       OptionalDeviceGuard g2(d2);
277   //       {
278   //         OptionalDeviceGuard g3(std::move(g1)); // !!!
279   //       }
280   //     }
281   //
282   // What should the current device be while g3 in scope... and what
283   // should it be after it goes out of scope?  What about g2?
284   // There don't seem to be satisfactory answers for these questions.
285   //
286   // It's in principle possible to raise an error when this occurs
287   // by doing some extra thread-local bookkeeping.  But why bother?
288   // Just don't provide the constructor.
289   InlineOptionalDeviceGuard(InlineOptionalDeviceGuard<T>&& other) = delete;
290 
291   // Note [Move assignment for RAII guards is tricky]
292   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
293   // Move assignment is deleted, because you need to know which guard was
294   // defined "first", as that guard's original_device_ wins--with the current
295   // representation, we have no way of telling which is the case.  (Move
296   // construction does not have this problem, as one guard is always
297   // uninitialized.)
298   //
299   // We can make this clear by way of a pair of examples:
300   //
301   // Example 1:
302   //
303   //  // initial device is n0
304   //  {
305   //    CUDAGuard g1(n1);
306   //    {
307   //      CUDAGuard g2(n2);
308   //      // current device should be n2
309   //      g1 = std::move(g2);
310   //      // current device should still be n2
311   //    }
312   //    // current device should still be n2
313   //  }
314   //  // current device should be n0
315   //
316   //  Example 2 (flip the order of the two guards):
317   //
318   //  // initial device is n0
319   //  {
320   //    CUDAGuard g2(n2);
321   //    {
322   //      CUDAGuard g1(n1);
323   //      // current device should be n1
324   //      g1 = std::move(g2);
325   //      // current device should be n2
326   //    }
327   //    // current device should be n0 (since g2 has been vacated)
328   //  }
329   //
330   // In both examples, we need g1 to restore to n0 after move assignment.
331   // However, in example 1, this is determined by the restore value of g1
332   // (prior to the move). In example 2, however, it is determined by the the
333   // restore value of g2(!!). We don't know which one should win, without having
334   // a way of telling which guard was allocated first.
335   //
336   // We could solve this with an extra thread-local variable.  But no one is
337   // actually using move-assignment.  So just get rid of it.
338   InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) =
339       delete;
340 
341   /// Sets the device to the given one.  Initializes OptionalDeviceGuard if it
342   /// is not already initialized.
343   template <
344       typename U = T,
345       typename =
346           typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
set_device(at::Device device)347   void set_device(at::Device device) {
348     if (!guard_.has_value()) {
349       guard_.emplace(device);
350     } else {
351       guard_->set_device(device);
352     }
353   }
354 
355   /// Resets the currently set device to its original device, and then sets the
356   /// current device to the passed device (for a possibly different device
357   /// type).  Initializes OptionalDeviceGuard if it is not already initialized.
358   ///
359   /// See notes on why this is called reset_device on InlineDeviceGuard.
360   ///
361   /// Optional argument is for testing only.
362   template <
363       typename U = T,
364       typename = typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>>>
365   void reset_device(
366       at::Device device,
367       const DeviceGuardImplInterface* impl = nullptr) {
368     if (!guard_.has_value()) {
369       guard_.emplace(device, impl);
370     } else {
371       guard_->reset_device(device, impl);
372     }
373   }
374 
375   /// Resets the currently set device to its original device, and then sets the
376   /// current device to the passed device.  Initializes the guard if it is
377   /// not already initialized.  This is effectively equivalent to set_device
378   /// when a guard supports only a single device type.
379   template <
380       typename U = T,
381       typename =
382           typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
reset_device(at::Device device)383   void reset_device(at::Device device) {
384     if (!guard_.has_value()) {
385       guard_.emplace(device);
386     } else {
387       guard_->reset_device(device);
388     }
389   }
390 
391   /// Sets the device index to the given one.  The device type is statically
392   /// known.
393   template <
394       typename U = T,
395       typename =
396           typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
set_index(DeviceIndex index)397   void set_index(DeviceIndex index) {
398     if (!guard_.has_value()) {
399       guard_.emplace(index);
400     } else {
401       guard_->set_index(index);
402     }
403   }
404 
405   /// Returns the device that was set immediately prior to initialization of
406   /// the, guard, or nullopt if the guard is uninitialized.
original_device()407   std::optional<Device> original_device() const {
408     return guard_.has_value() ? std::make_optional(guard_->original_device())
409                               : std::nullopt;
410   }
411 
412   /// Returns the most recent device that was set using this device guard,
413   /// either from construction, or via set_device, if the guard is initialized,
414   /// or nullopt if the guard is uninitialized.
current_device()415   std::optional<Device> current_device() const {
416     return guard_.has_value() ? std::make_optional(guard_->current_device())
417                               : std::nullopt;
418   }
419 
420   /// Restore the original device, resetting this guard to uninitialized state.
reset()421   void reset() {
422     guard_.reset();
423   }
424 
425  private:
426   std::optional<InlineDeviceGuard<T>> guard_;
427 };
428 
429 } // namespace c10::impl
430