1 #pragma once 2 3 #include <c10/core/impl/InlineDeviceGuard.h> 4 #include <c10/util/ArrayRef.h> 5 #include <c10/util/irange.h> 6 7 namespace c10::impl { 8 9 /** 10 * A StreamGuard is an RAII class that changes the current device 11 * to the device corresponding to some stream, and changes the 12 * default stream on that device to be this stream. 13 * 14 * InlineStreamGuard is a helper class for implementing StreamGuards. 15 * See InlineDeviceGuard for guidance on how to use this class. 16 */ 17 template <typename T> 18 class InlineStreamGuard : private InlineDeviceGuard<T> { 19 public: 20 /// No default constructor, see Note [Omitted default constructor from RAII] 21 explicit InlineStreamGuard() = delete; 22 23 /// Set the current device to the device associated with the passed stream, 24 /// and set the current stream on that device to the passed stream. InlineStreamGuard(Stream stream)25 explicit InlineStreamGuard(Stream stream) 26 : InlineDeviceGuard<T>(stream.device()), 27 original_stream_of_original_device_( 28 this->impl_.getStream(original_device())), 29 original_stream_of_current_device_(this->impl_.exchangeStream(stream)), 30 current_stream_(stream) {} 31 32 /// This constructor exists purely for testing 33 template < 34 typename U = T, 35 typename = typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>>> InlineStreamGuard(Stream stream,const DeviceGuardImplInterface * impl)36 explicit InlineStreamGuard( 37 Stream stream, 38 const DeviceGuardImplInterface* impl) 39 : InlineDeviceGuard<T>( 40 stream.device(), 41 impl ? impl : getDeviceGuardImpl(stream.device_type())), 42 original_stream_of_original_device_( 43 this->impl_.getStream(original_device())), 44 original_stream_of_current_device_(this->impl_.exchangeStream(stream)), 45 current_stream_(stream) {} 46 47 /// Copy is disallowed 48 InlineStreamGuard(const InlineStreamGuard<T>&) = delete; 49 InlineStreamGuard<T>& operator=(const InlineStreamGuard<T>&) = delete; 50 51 /// Move is disallowed, as StreamGuard does not have an uninitialized state, 52 /// which is required for moves on types with nontrivial destructors. 53 InlineStreamGuard(InlineStreamGuard<T>&& other) = delete; 54 InlineStreamGuard& operator=(InlineStreamGuard<T>&& other) = delete; 55 ~InlineStreamGuard()56 ~InlineStreamGuard() { 57 this->impl_.exchangeStream(original_stream_of_current_device_); 58 } 59 60 /// Resets the currently set stream to the original stream and 61 /// the currently set device to the original device. Then, 62 /// set the current device to the device associated with the passed stream, 63 /// and set the current stream on that device to the passed stream. 64 /// 65 /// NOTE: this implementation may skip some stream/device setting if 66 /// it can prove that it is unnecessary. 67 /// 68 /// WARNING: reset_stream does NOT preserve previously set streams on 69 /// different devices. If you need to set streams on multiple devices 70 /// use MultiStreamGuard instead. reset_stream(Stream stream)71 void reset_stream(Stream stream) { 72 // TODO: make a version that takes an impl argument. Unfortunately, 73 // that will require SFINAE because impl is only valid for the 74 // VirtualGuardImpl specialization. 75 if (stream.device() == this->current_device()) { 76 this->impl_.exchangeStream(stream); 77 current_stream_ = stream; 78 } else { 79 // Destruct and reconstruct the StreamGuard in-place 80 this->impl_.exchangeStream(original_stream_of_current_device_); 81 this->reset_device(stream.device()); 82 original_stream_of_current_device_ = this->impl_.exchangeStream(stream); 83 current_stream_ = stream; 84 } 85 } 86 87 // It's not clear if set_device should also reset the current stream 88 // if the device is unchanged; therefore, we don't provide it. 89 // The situation is somewhat clearer with reset_device, but it's still 90 // a pretty weird thing to do, so haven't added this either. 91 92 /// Returns the stream of the original device prior to this guard. Subtly, 93 /// the stream returned here is the original stream of the *original* 94 /// device; i.e., it's the stream that your computation *would* have 95 /// been put on, if it hadn't been for this meddling stream guard. 96 /// This is usually what you want. original_stream()97 Stream original_stream() const { 98 return original_stream_of_original_device_; 99 } 100 101 /// Returns the most recent stream that was set using this device guard, 102 /// either from construction, or via set_stream. current_stream()103 Stream current_stream() const { 104 return current_stream_; 105 } 106 107 /// Returns the most recent device that was set using this device guard, 108 /// either from construction, or via set_device/reset_device/set_index. current_device()109 Device current_device() const { 110 return InlineDeviceGuard<T>::current_device(); 111 } 112 113 /// Returns the device that was set at the most recent reset_stream(), 114 /// or otherwise the device at construction time. original_device()115 Device original_device() const { 116 return InlineDeviceGuard<T>::original_device(); 117 } 118 119 private: 120 Stream 121 original_stream_of_original_device_; // what the user probably cares about 122 Stream original_stream_of_current_device_; // what we need to restore 123 Stream current_stream_; 124 }; 125 126 /** 127 * An OptionalStreamGuard is an RAII class that sets a device to some value on 128 * initialization, and resets the device to its original value on destruction. 129 * See InlineOptionalDeviceGuard for more guidance on how to use this class. 130 */ 131 template <typename T> 132 class InlineOptionalStreamGuard { 133 public: 134 /// Creates an uninitialized stream guard. InlineOptionalStreamGuard()135 explicit InlineOptionalStreamGuard() 136 : guard_() // See Note [Explicit initialization of optional fields] 137 {} 138 139 /// Set the current device to the device associated with the passed stream, 140 /// and set the current stream on that device to the passed stream, 141 /// if the passed stream is not nullopt. InlineOptionalStreamGuard(std::optional<Stream> stream_opt)142 explicit InlineOptionalStreamGuard(std::optional<Stream> stream_opt) 143 : guard_() { 144 if (stream_opt.has_value()) { 145 guard_.emplace(stream_opt.value()); 146 } 147 } 148 149 /// All constructors of StreamGuard are valid for OptionalStreamGuard 150 template <typename... Args> InlineOptionalStreamGuard(Args &&...args)151 explicit InlineOptionalStreamGuard(Args&&... args) 152 : guard_(std::in_place, std::forward<Args>(args)...) {} 153 154 // See Note [Move construction for RAII guards is tricky] 155 InlineOptionalStreamGuard(InlineOptionalStreamGuard<T>&& other) = delete; 156 157 // See Note [Move assignment for RAII guards is tricky] 158 InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) = 159 delete; 160 161 /// Resets the currently set stream to the original stream and 162 /// the currently set device to the original device. Then, 163 /// set the current device to the device associated with the passed stream, 164 /// and set the current stream on that device to the passed stream. 165 /// Initializes the OptionalStreamGuard if it was not previously initialized. reset_stream(Stream stream)166 void reset_stream(Stream stream) { 167 if (guard_.has_value()) { 168 guard_->reset_stream(stream); 169 } else { 170 guard_.emplace(stream); 171 } 172 } 173 174 /// Returns the stream that was set at the time the guard was most recently 175 /// initialized, or nullopt if the guard is uninitialized. original_stream()176 std::optional<Stream> original_stream() const { 177 return guard_.has_value() ? std::make_optional(guard_->original_stream()) 178 : std::nullopt; 179 } 180 181 /// Returns the most recent stream that was set using this stream guard, 182 /// either from construction, or via reset_stream, if the guard is 183 /// initialized, or nullopt if the guard is uninitialized. current_stream()184 std::optional<Stream> current_stream() const { 185 return guard_.has_value() ? std::make_optional(guard_->current_stream()) 186 : std::nullopt; 187 } 188 189 /// Restore the original device and stream, resetting this guard to 190 /// uninitialized state. reset()191 void reset() { 192 guard_.reset(); 193 } 194 195 private: 196 std::optional<InlineStreamGuard<T>> guard_; 197 }; 198 199 template <typename T> 200 class InlineMultiStreamGuard { 201 public: 202 /// Calls `set_stream` on each of the streams in the list. 203 /// This may be useful if you need to set different streams 204 /// for different devices. InlineMultiStreamGuard(ArrayRef<Stream> streams)205 explicit InlineMultiStreamGuard(ArrayRef<Stream> streams) { 206 if (!streams.empty()) { 207 impl_.emplace(getDeviceTypeOfStreams(streams)); 208 original_streams_.reserve(streams.size()); 209 for (const Stream& s : streams) { 210 original_streams_.emplace_back(this->impl_->exchangeStream(s)); 211 } 212 } 213 } 214 215 /// Copy is disallowed 216 InlineMultiStreamGuard(const InlineMultiStreamGuard&) = delete; 217 InlineMultiStreamGuard<T>& operator=(const InlineMultiStreamGuard&) = delete; 218 219 /// Move is disallowed, as StreamGuard does not have an uninitialized state, 220 /// which is required for moves on types with nontrivial destructors. 221 InlineMultiStreamGuard(InlineMultiStreamGuard&& other) = delete; 222 InlineMultiStreamGuard& operator=(InlineMultiStreamGuard&& other) = delete; 223 ~InlineMultiStreamGuard()224 ~InlineMultiStreamGuard() noexcept { 225 if (this->impl_.has_value()) { 226 for (const Stream& s : original_streams_) { 227 this->impl_->exchangeStream(s); 228 } 229 } 230 } 231 232 protected: 233 std::optional<T> impl_; 234 235 private: 236 /// The original streams that were active on all devices. 237 std::vector<Stream> original_streams_; 238 getDeviceTypeOfStreams(ArrayRef<Stream> streams)239 static DeviceType getDeviceTypeOfStreams(ArrayRef<Stream> streams) { 240 TORCH_INTERNAL_ASSERT(!streams.empty()); 241 DeviceType type = streams[0].device_type(); 242 for (const auto idx : c10::irange(1, streams.size())) { 243 TORCH_CHECK_VALUE( 244 streams[idx].device_type() == type, 245 "Streams have a mix of device types: stream 0 is on ", 246 streams[0].device(), 247 " while stream ", 248 idx, 249 " is on device ", 250 streams[idx].device()); 251 } 252 return type; 253 } 254 }; 255 256 } // namespace c10::impl 257