xref: /aosp_15_r20/external/pytorch/c10/core/impl/InlineStreamGuard.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/irange.h>

#include <optional>
#include <type_traits>
#include <utility>
#include <vector>
6 
7 namespace c10::impl {
8 
9 /**
10  * A StreamGuard is an RAII class that changes the current device
11  * to the device corresponding to some stream, and changes the
12  * default stream on that device to be this stream.
13  *
14  * InlineStreamGuard is a helper class for implementing StreamGuards.
15  * See InlineDeviceGuard for guidance on how to use this class.
16  */
17 template <typename T>
18 class InlineStreamGuard : private InlineDeviceGuard<T> {
19  public:
20   /// No default constructor, see Note [Omitted default constructor from RAII]
21   explicit InlineStreamGuard() = delete;
22 
23   /// Set the current device to the device associated with the passed stream,
24   /// and set the current stream on that device to the passed stream.
InlineStreamGuard(Stream stream)25   explicit InlineStreamGuard(Stream stream)
26       : InlineDeviceGuard<T>(stream.device()),
27         original_stream_of_original_device_(
28             this->impl_.getStream(original_device())),
29         original_stream_of_current_device_(this->impl_.exchangeStream(stream)),
30         current_stream_(stream) {}
31 
32   /// This constructor exists purely for testing
33   template <
34       typename U = T,
35       typename = typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>>>
InlineStreamGuard(Stream stream,const DeviceGuardImplInterface * impl)36   explicit InlineStreamGuard(
37       Stream stream,
38       const DeviceGuardImplInterface* impl)
39       : InlineDeviceGuard<T>(
40             stream.device(),
41             impl ? impl : getDeviceGuardImpl(stream.device_type())),
42         original_stream_of_original_device_(
43             this->impl_.getStream(original_device())),
44         original_stream_of_current_device_(this->impl_.exchangeStream(stream)),
45         current_stream_(stream) {}
46 
47   /// Copy is disallowed
48   InlineStreamGuard(const InlineStreamGuard<T>&) = delete;
49   InlineStreamGuard<T>& operator=(const InlineStreamGuard<T>&) = delete;
50 
51   /// Move is disallowed, as StreamGuard does not have an uninitialized state,
52   /// which is required for moves on types with nontrivial destructors.
53   InlineStreamGuard(InlineStreamGuard<T>&& other) = delete;
54   InlineStreamGuard& operator=(InlineStreamGuard<T>&& other) = delete;
55 
~InlineStreamGuard()56   ~InlineStreamGuard() {
57     this->impl_.exchangeStream(original_stream_of_current_device_);
58   }
59 
60   /// Resets the currently set stream to the original stream and
61   /// the currently set device to the original device.  Then,
62   /// set the current device to the device associated with the passed stream,
63   /// and set the current stream on that device to the passed stream.
64   ///
65   /// NOTE: this implementation may skip some stream/device setting if
66   /// it can prove that it is unnecessary.
67   ///
68   /// WARNING: reset_stream does NOT preserve previously set streams on
69   /// different devices.  If you need to set streams on multiple devices
70   /// use MultiStreamGuard instead.
reset_stream(Stream stream)71   void reset_stream(Stream stream) {
72     // TODO: make a version that takes an impl argument.  Unfortunately,
73     // that will require SFINAE because impl is only valid for the
74     // VirtualGuardImpl specialization.
75     if (stream.device() == this->current_device()) {
76       this->impl_.exchangeStream(stream);
77       current_stream_ = stream;
78     } else {
79       // Destruct and reconstruct the StreamGuard in-place
80       this->impl_.exchangeStream(original_stream_of_current_device_);
81       this->reset_device(stream.device());
82       original_stream_of_current_device_ = this->impl_.exchangeStream(stream);
83       current_stream_ = stream;
84     }
85   }
86 
87   // It's not clear if set_device should also reset the current stream
88   // if the device is unchanged; therefore, we don't provide it.
89   // The situation is somewhat clearer with reset_device, but it's still
90   // a pretty weird thing to do, so haven't added this either.
91 
92   /// Returns the stream of the original device prior to this guard.  Subtly,
93   /// the stream returned here is the original stream of the *original*
94   /// device; i.e., it's the stream that your computation *would* have
95   /// been put on, if it hadn't been for this meddling stream guard.
96   /// This is usually what you want.
original_stream()97   Stream original_stream() const {
98     return original_stream_of_original_device_;
99   }
100 
101   /// Returns the most recent stream that was set using this device guard,
102   /// either from construction, or via set_stream.
current_stream()103   Stream current_stream() const {
104     return current_stream_;
105   }
106 
107   /// Returns the most recent device that was set using this device guard,
108   /// either from construction, or via set_device/reset_device/set_index.
current_device()109   Device current_device() const {
110     return InlineDeviceGuard<T>::current_device();
111   }
112 
113   /// Returns the device that was set at the most recent reset_stream(),
114   /// or otherwise the device at construction time.
original_device()115   Device original_device() const {
116     return InlineDeviceGuard<T>::original_device();
117   }
118 
119  private:
120   Stream
121       original_stream_of_original_device_; // what the user probably cares about
122   Stream original_stream_of_current_device_; // what we need to restore
123   Stream current_stream_;
124 };
125 
126 /**
127  * An OptionalStreamGuard is an RAII class that sets a device to some value on
128  * initialization, and resets the device to its original value on destruction.
129  * See InlineOptionalDeviceGuard for more guidance on how to use this class.
130  */
131 template <typename T>
132 class InlineOptionalStreamGuard {
133  public:
134   /// Creates an uninitialized stream guard.
InlineOptionalStreamGuard()135   explicit InlineOptionalStreamGuard()
136       : guard_() // See Note [Explicit initialization of optional fields]
137   {}
138 
139   /// Set the current device to the device associated with the passed stream,
140   /// and set the current stream on that device to the passed stream,
141   /// if the passed stream is not nullopt.
InlineOptionalStreamGuard(std::optional<Stream> stream_opt)142   explicit InlineOptionalStreamGuard(std::optional<Stream> stream_opt)
143       : guard_() {
144     if (stream_opt.has_value()) {
145       guard_.emplace(stream_opt.value());
146     }
147   }
148 
149   /// All constructors of StreamGuard are valid for OptionalStreamGuard
150   template <typename... Args>
InlineOptionalStreamGuard(Args &&...args)151   explicit InlineOptionalStreamGuard(Args&&... args)
152       : guard_(std::in_place, std::forward<Args>(args)...) {}
153 
154   // See Note [Move construction for RAII guards is tricky]
155   InlineOptionalStreamGuard(InlineOptionalStreamGuard<T>&& other) = delete;
156 
157   // See Note [Move assignment for RAII guards is tricky]
158   InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) =
159       delete;
160 
161   /// Resets the currently set stream to the original stream and
162   /// the currently set device to the original device.  Then,
163   /// set the current device to the device associated with the passed stream,
164   /// and set the current stream on that device to the passed stream.
165   /// Initializes the OptionalStreamGuard if it was not previously initialized.
reset_stream(Stream stream)166   void reset_stream(Stream stream) {
167     if (guard_.has_value()) {
168       guard_->reset_stream(stream);
169     } else {
170       guard_.emplace(stream);
171     }
172   }
173 
174   /// Returns the stream that was set at the time the guard was most recently
175   /// initialized, or nullopt if the guard is uninitialized.
original_stream()176   std::optional<Stream> original_stream() const {
177     return guard_.has_value() ? std::make_optional(guard_->original_stream())
178                               : std::nullopt;
179   }
180 
181   /// Returns the most recent stream that was set using this stream guard,
182   /// either from construction, or via reset_stream, if the guard is
183   /// initialized, or nullopt if the guard is uninitialized.
current_stream()184   std::optional<Stream> current_stream() const {
185     return guard_.has_value() ? std::make_optional(guard_->current_stream())
186                               : std::nullopt;
187   }
188 
189   /// Restore the original device and stream, resetting this guard to
190   /// uninitialized state.
reset()191   void reset() {
192     guard_.reset();
193   }
194 
195  private:
196   std::optional<InlineStreamGuard<T>> guard_;
197 };
198 
199 template <typename T>
200 class InlineMultiStreamGuard {
201  public:
202   /// Calls `set_stream` on each of the streams in the list.
203   /// This may be useful if you need to set different streams
204   /// for different devices.
InlineMultiStreamGuard(ArrayRef<Stream> streams)205   explicit InlineMultiStreamGuard(ArrayRef<Stream> streams) {
206     if (!streams.empty()) {
207       impl_.emplace(getDeviceTypeOfStreams(streams));
208       original_streams_.reserve(streams.size());
209       for (const Stream& s : streams) {
210         original_streams_.emplace_back(this->impl_->exchangeStream(s));
211       }
212     }
213   }
214 
215   /// Copy is disallowed
216   InlineMultiStreamGuard(const InlineMultiStreamGuard&) = delete;
217   InlineMultiStreamGuard<T>& operator=(const InlineMultiStreamGuard&) = delete;
218 
219   /// Move is disallowed, as StreamGuard does not have an uninitialized state,
220   /// which is required for moves on types with nontrivial destructors.
221   InlineMultiStreamGuard(InlineMultiStreamGuard&& other) = delete;
222   InlineMultiStreamGuard& operator=(InlineMultiStreamGuard&& other) = delete;
223 
~InlineMultiStreamGuard()224   ~InlineMultiStreamGuard() noexcept {
225     if (this->impl_.has_value()) {
226       for (const Stream& s : original_streams_) {
227         this->impl_->exchangeStream(s);
228       }
229     }
230   }
231 
232  protected:
233   std::optional<T> impl_;
234 
235  private:
236   /// The original streams that were active on all devices.
237   std::vector<Stream> original_streams_;
238 
getDeviceTypeOfStreams(ArrayRef<Stream> streams)239   static DeviceType getDeviceTypeOfStreams(ArrayRef<Stream> streams) {
240     TORCH_INTERNAL_ASSERT(!streams.empty());
241     DeviceType type = streams[0].device_type();
242     for (const auto idx : c10::irange(1, streams.size())) {
243       TORCH_CHECK_VALUE(
244           streams[idx].device_type() == type,
245           "Streams have a mix of device types: stream 0 is on ",
246           streams[0].device(),
247           " while stream ",
248           idx,
249           " is on device ",
250           streams[idx].device());
251     }
252     return type;
253   }
254 };
255 
256 } // namespace c10::impl
257