xref: /aosp_15_r20/external/pytorch/c10/core/StreamGuard.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Device.h>
4 #include <c10/core/Stream.h>
5 #include <c10/core/impl/InlineStreamGuard.h>
6 #include <c10/core/impl/VirtualGuardImpl.h>
7 #include <c10/util/ArrayRef.h>
8 #include <c10/util/Optional.h>
9 
10 namespace c10 {
11 
12 /**
13  * A StreamGuard is an RAII class that changes the current device
14  * to the device corresponding to some stream, and changes the
15  * default stream on that device to be this stream.
16  *
17  * Use of StreamGuard is HIGHLY discouraged in operator definitions.  In
18  * a single operator, you probably don't know enough about the global
19  * state of the world to profitably decide how to set streams.  Let
20  * the caller handle this appropriately, and just use the current stream
21  * in your operator code.
22  *
23  * This StreamGuard does NOT have an uninitialized state; it is guaranteed
24  * to reset the stream and device on exit.  If you are in a situation
25  * where you *might* want to setup a stream guard, see OptionalStreamGuard.
26  */
27 struct StreamGuard {
28   /// No default constructor, see Note [Omitted default constructor from RAII]
29   explicit StreamGuard() = delete;
30 
31   /// Set the current device to the device associated with the passed stream,
32   /// and set the current  stream on that device to the passed stream.
StreamGuardStreamGuard33   explicit StreamGuard(Stream stream) : guard_(stream) {}
34 
35   /// Copy is disallowed
36   StreamGuard(const StreamGuard&) = delete;
37   StreamGuard& operator=(const StreamGuard&) = delete;
38 
39   /// Move is disallowed, as StreamGuard does not have an uninitialized state,
40   /// which is required for moves on types with nontrivial destructors.
41   StreamGuard(StreamGuard&& other) = delete;
42   StreamGuard& operator=(StreamGuard&& other) = delete;
43 
44   /// Resets the currently set stream to the original stream and
45   /// the currently set device to the original device.  Then,
46   /// set the current device to the device associated with the passed stream,
47   /// and set the current stream on that device to the passed stream.
48   ///
49   /// NOTE: this implementation may skip some stream/device setting if
50   /// it can prove that it is unnecessary.
51   ///
52   /// WARNING: reset_stream does NOT preserve previously set streams on
53   /// different devices.  If you need to set streams on multiple devices
54   /// on , use MultiStreamGuard instead.
reset_streamStreamGuard55   void reset_stream(Stream stream) {
56     guard_.reset_stream(stream);
57   }
58 
59   /// Returns the stream that was set at the time the guard was constructed.
original_streamStreamGuard60   Stream original_stream() const {
61     return guard_.original_stream();
62   }
63 
64   /// Returns the most recent stream that was set using this device guard,
65   /// either from construction, or via set_stream.
current_streamStreamGuard66   Stream current_stream() const {
67     return guard_.current_stream();
68   }
69 
70   /// Returns the most recent device that was set using this device guard,
71   /// either from construction, or via set_device/reset_device/set_index.
current_deviceStreamGuard72   Device current_device() const {
73     return guard_.current_device();
74   }
75 
76   /// Returns the device that was set at the most recent reset_stream(),
77   /// or otherwise the device at construction time.
original_deviceStreamGuard78   Device original_device() const {
79     return guard_.original_device();
80   }
81 
82  private:
83   c10::impl::InlineStreamGuard<impl::VirtualGuardImpl> guard_;
84 };
85 
86 /**
87  * An OptionalStreamGuard is an RAII class that sets a device to some value on
88  * initialization, and resets the device to its original value on destruction.
89  * See OptionalDeviceGuard for more guidance on how to use this class.
90  */
91 struct OptionalStreamGuard {
92   /// Create an uninitialized guard.
93   explicit OptionalStreamGuard() = default;
94 
95   /// Set the current device to the device associated with the passed stream,
96   /// and set the current stream on that device to the passed stream.
OptionalStreamGuardOptionalStreamGuard97   explicit OptionalStreamGuard(Stream stream) : guard_(stream) {}
98 
99   /// Set the current device to the device associated with the passed stream,
100   /// and set the current stream on that device to the passed stream,
101   /// if the passed stream is not nullopt.
OptionalStreamGuardOptionalStreamGuard102   explicit OptionalStreamGuard(std::optional<Stream> stream_opt)
103       : guard_(stream_opt) {}
104 
105   /// Copy is disallowed
106   OptionalStreamGuard(const OptionalStreamGuard&) = delete;
107   OptionalStreamGuard& operator=(const OptionalStreamGuard&) = delete;
108 
109   // See Note [Move construction for RAII guards is tricky]
110   OptionalStreamGuard(OptionalStreamGuard&& other) = delete;
111 
112   // See Note [Move assignment for RAII guards is tricky]
113   OptionalStreamGuard& operator=(OptionalStreamGuard&& other) = delete;
114 
115   /// Resets the currently set stream to the original stream and
116   /// the currently set device to the original device.  Then,
117   /// set the current device to the device associated with the passed stream,
118   /// and set the current stream on that device to the passed stream.
119   /// Initializes the guard if it was not previously initialized.
reset_streamOptionalStreamGuard120   void reset_stream(Stream stream) {
121     guard_.reset_stream(stream);
122   }
123 
124   /// Returns the stream that was set at the time the guard was most recently
125   /// initialized, or nullopt if the guard is uninitialized.
original_streamOptionalStreamGuard126   std::optional<Stream> original_stream() const {
127     return guard_.original_stream();
128   }
129 
130   /// Returns the most recent  stream that was set using this stream guard,
131   /// either from construction, or via reset_stream, if the guard is
132   /// initialized, or nullopt if the guard is uninitialized.
current_streamOptionalStreamGuard133   std::optional<Stream> current_stream() const {
134     return guard_.current_stream();
135   }
136 
137   /// Restore the original  device and stream, resetting this guard to
138   /// uninitialized state.
resetOptionalStreamGuard139   void reset() {
140     guard_.reset();
141   }
142 
143  private:
144   c10::impl::InlineOptionalStreamGuard<impl::VirtualGuardImpl> guard_{};
145 };
146 
147 /**
148  * A MultiStreamGuard is an RAII class that sets the current streams of a set of
149  * devices all at once, and resets them to their original values on destruction.
150  */
151 struct MultiStreamGuard {
152   /// Set the current streams to the passed streams on each of their respective
153   /// devices.
MultiStreamGuardMultiStreamGuard154   explicit MultiStreamGuard(ArrayRef<Stream> streams) : guard_(streams) {}
155 
156   /// Copy is disallowed
157   MultiStreamGuard(const MultiStreamGuard&) = delete;
158   MultiStreamGuard& operator=(const MultiStreamGuard&) = delete;
159 
160   // See Note [Move construction for RAII guards is tricky]
161   MultiStreamGuard(MultiStreamGuard&& other) = delete;
162 
163   // See Note [Move assignment for RAII guards is tricky]
164   MultiStreamGuard& operator=(MultiStreamGuard&& other) = delete;
165 
166  private:
167   c10::impl::InlineMultiStreamGuard<impl::VirtualGuardImpl> guard_;
168 };
169 
170 } // namespace c10
171