#pragma once

#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>

// Just for C10_ANONYMOUS_VARIABLE
#include <c10/util/Registry.h>

#include <atomic>

namespace c10 {

// Forward declaration
class DataPtr;

/**
 * Note [Flags defining the behavior of events]
 *
 * PYTORCH_DEFAULT and BACKEND_DEFAULT are valid for all backends. The
 * BACKEND_DEFAULT is what a particular backend would select if no
 * flags were given. PYTORCH_DEFAULT is PyTorch's framework default
 * choice for events on that backend, which may not be the same.
 *
 * The mapping of PYTORCH_DEFAULT and BACKEND_DEFAULT is done by each
 * backend implementation.
 */
enum class EventFlag {
  // Disable timing
  PYTORCH_DEFAULT,
  // Enable timing
  BACKEND_DEFAULT,
  // FOR TESTING ONLY
  INVALID
};
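
// As an illustration only (a hypothetical sketch, not any particular
// backend's actual mapping): a backend whose native events enable timing by
// default might translate these flags roughly as follows, where
// NATIVE_EVENT_DEFAULT and NATIVE_EVENT_DISABLE_TIMING stand in for the
// backend's real constants (e.g., cudaEventDefault / cudaEventDisableTiming
// on CUDA):
//
//   unsigned int toNativeFlags(EventFlag flag) {
//     switch (flag) {
//       case EventFlag::PYTORCH_DEFAULT:
//         return NATIVE_EVENT_DISABLE_TIMING; // PyTorch default: no timing
//       case EventFlag::BACKEND_DEFAULT:
//         return NATIVE_EVENT_DEFAULT; // backend default: timing enabled
//       default:
//         TORCH_CHECK(false, "Unsupported event flag");
//     }
//   }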

namespace impl {

/**
 * DeviceGuardImplInterface represents the virtual interface which provides
 * the functionality needed to implement an RAII class for device and stream
 * switching, via DeviceGuard. Every distinct device type, e.g., CUDA and HIP,
 * is expected to implement and register an implementation of this interface.
 * All classes which inherit from DeviceGuardImplInterface should be declared
 * 'final'.
 *
 * This class exists because we provide a unified interface for performing
 * device guards via DeviceGuard, but we cannot assume that we have actually
 * compiled against the, e.g., CUDA library, which actually implements
 * this guard functionality. In this case, a dynamic dispatch is required
 * to cross the library boundary.
 *
 * If possible, you should directly use implementations of this interface;
 * those uses will be devirtualized.
 */
struct C10_API DeviceGuardImplInterface {
  DeviceGuardImplInterface() = default;
  DeviceGuardImplInterface(const DeviceGuardImplInterface&) = default;
  DeviceGuardImplInterface& operator=(const DeviceGuardImplInterface&) =
      default;
  DeviceGuardImplInterface(DeviceGuardImplInterface&&) noexcept = default;
  DeviceGuardImplInterface& operator=(DeviceGuardImplInterface&&) noexcept =
      default;

  /**
   * Return the type of device managed by this guard implementation.
   */
  virtual DeviceType type() const = 0;

  /**
   * Set the current device to Device, and return the previous Device.
   */
  virtual Device exchangeDevice(Device) const = 0;
  // NB: Implementations of exchangeDevice can be a bit boilerplatey. You might
  // consider replacing exchangeDevice with a non-virtual function with a baked
  // in implementation; however, note that this will triple the number of
  // virtual calls (when you implement exchangeDevice in a final subclass,
  // the compiler gets to devirtualize everything; it won't do that if you
  // don't define it in the subclass!) A common way to solve this problem is to
  // use some sort of CRTP; however, we can't template DeviceGuardImplInterface
  // since we really *do* need it to be virtual. A little boilerplate seems
  // easiest to explain. (Another way around this problem is to provide inline
  // functions with the default implementations, but this seems a little hard
  // to explain. In any case, we're only going to have on the order of ten
  // implementations of this anyway.)
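
  // As a concrete illustration of the boilerplate (a minimal sketch for a
  // hypothetical final subclass; a real backend would also call into its
  // runtime API):
  //
  //   Device exchangeDevice(Device d) const override {
  //     Device old = getDevice();
  //     if (old != d) {
  //       setDevice(d);
  //     }
  //     return old;
  //   }
  //
  // Because the subclass is final and defines the method itself, calls made
  // through a pointer or reference to the concrete subclass type can be
  // devirtualized.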

  /**
   * Get the current device.
   */
  virtual Device getDevice() const = 0;

  /**
   * Set the current device to Device.
   */
  virtual void setDevice(Device) const = 0;

  /**
   * Set the current device to Device, without checking for errors
   * (so, e.g., this can be called from a destructor).
   */
  virtual void uncheckedSetDevice(Device) const noexcept = 0;

  /**
   * Get the current stream for a given device.
   */
  virtual Stream getStream(Device) const noexcept = 0;

  /**
   * Get the default stream for a given device.
   */
  virtual Stream getDefaultStream(Device) const {
    TORCH_CHECK(false, "Backend doesn't support acquiring a default stream.");
  }

  /**
   * Get a stream from the global pool for a given device.
   */
  virtual Stream getStreamFromGlobalPool(Device, bool isHighPriority = false)
      const {
    (void)isHighPriority; // Suppress unused variable warning
    TORCH_CHECK(false, "Backend doesn't support acquiring a stream from pool.");
  }

  /**
   * Return a new stream for a given device and priority. The stream will be
   * copied and shared around, so the device backend should be able to
   * correctly handle the lifetime of the stream.
   */
  virtual Stream getNewStream(Device, int priority = 0) const {
    (void)priority; // Suppress unused variable warning
    TORCH_CHECK(false, "Backend doesn't support creating a new Stream.");
  }

  /**
   * Set a stream to be the thread local current stream for its device.
   * Return the previous stream for that device. You are NOT required
   * to set the current device to match the device of this stream.
   */
  virtual Stream exchangeStream(Stream) const noexcept = 0;

  /**
   * Destroys the given event.
   */
  virtual void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/)
      const noexcept {}

  /**
   * Increments the event's version and enqueues a job with this version
   * in the stream's work queue. When the stream processes that job
   * it notifies all streams waiting on / blocked by that version of the
   * event to continue and marks that version as recorded.
   */
  virtual void record(
      void** /*event*/,
      const Stream& /*stream*/,
      const DeviceIndex /*device_index*/,
      const c10::EventFlag /*flag*/) const {
    TORCH_CHECK(false, "Backend doesn't support events.");
  }

  /**
   * Does nothing if the event has not been scheduled to be recorded.
   * If the event was previously enqueued to be recorded, a command
   * to wait for the version of the event that exists at the time of this call
   * is inserted in the stream's work queue.
   * When the stream reaches this command it will stop processing
   * additional commands until that version of the event is marked as recorded.
   */
  virtual void block(void* /*event*/, const Stream& /*stream*/) const {
    TORCH_CHECK(false, "Backend doesn't support events.");
  }
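
  // Taken together, record() and block() implement cross-stream ordering.
  // A sketch of how a caller might drive them (assuming impl points to a
  // backend implementation of this interface; the event handle starts out
  // null and is allocated lazily by record()):
  //
  //   void* event = nullptr;
  //   // Producer: mark a point in stream_a's work queue.
  //   impl->record(
  //       &event, stream_a, stream_a.device_index(),
  //       EventFlag::PYTORCH_DEFAULT);
  //   // Consumer: stream_b stops processing further commands until
  //   // stream_a reaches the recorded point.
  //   impl->block(event, stream_b);
  //   ...
  //   impl->destroyEvent(event, stream_a.device_index());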

  /**
   * Returns true if (and only if)
   * (1) the event has never been scheduled to be recorded, or
   * (2) the current version is marked as recorded.
   * Returns false otherwise.
   */
  virtual bool queryEvent(void* /*event*/) const {
    TORCH_CHECK(false, "Backend doesn't support events.");
  }

  /**
   * Get the number of devices. WARNING: This is REQUIRED to not raise
   * an exception. If there is some sort of problem, e.g., driver error,
   * you should report that there are zero available devices.
   */
  virtual DeviceIndex deviceCount() const noexcept = 0;

  /**
   * Return true if all the work previously enqueued on the stream for
   * asynchronous execution has completed running on the device.
   */
  virtual bool queryStream(const Stream& /*stream*/) const {
    TORCH_CHECK(false, "Backend doesn't support querying streams.");
  }

  /**
   * Wait (by blocking the calling thread) until all the work previously
   * enqueued on the stream has completed running on the device.
   */
  virtual void synchronizeStream(const Stream& /*stream*/) const {
    TORCH_CHECK(false, "Backend doesn't support synchronizing streams.");
  }

  /**
   * Wait (by blocking the calling thread) until all the work previously
   * recorded on the event has completed running on the device.
   */
  virtual void synchronizeEvent(void* /*event*/) const {
    TORCH_CHECK(false, "Backend doesn't support synchronizing events.");
  }

  /**
   * Ensure the caching allocator (if any) is aware that the given DataPtr is
   * being used on the given stream, and that it should thus avoid recycling
   * the DataPtr until all work on that stream is done.
   */
  virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const {
  }

  /**
   * Fetch the elapsed time between two recorded events.
   */
  virtual double elapsedTime(
      void* /*event1*/,
      void* /*event2*/,
      const DeviceIndex /*device_index*/) const {
    TORCH_CHECK(false, "Backend doesn't support elapsedTime.");
  }

  /**
   * Intended use of this class is to leak the DeviceGuardImpl at program end.
   * So you better not call the destructor, buster!
   */
  virtual ~DeviceGuardImplInterface() = default;
};

// A no-op device guard impl that doesn't do anything interesting. Useful
// for devices that don't actually have a concept of device index. Prominent
// examples are CPU and Meta.
template <DeviceType D>
struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
  NoOpDeviceGuardImpl() = default;
  DeviceType type() const override {
    return D;
  }
  Device exchangeDevice(Device) const override {
    return Device(D, -1); // no-op
  }
  Device getDevice() const override {
    return Device(D, -1);
  }
  void setDevice(Device) const override {
    // no-op
  }
  void uncheckedSetDevice(Device) const noexcept override {
    // no-op
  }
  Stream getStream(Device) const noexcept override {
    // no-op
    return Stream(Stream::DEFAULT, Device(D, -1));
  }

  Stream getNewStream(Device, int priority = 0) const override {
    // no-op
    (void)priority;
    return Stream(Stream::DEFAULT, Device(D, -1));
  }

  // NB: These do NOT set the current device
  Stream exchangeStream(Stream) const noexcept override {
    // no-op
    return Stream(Stream::DEFAULT, Device(D, -1));
  }
  DeviceIndex deviceCount() const noexcept override {
    return 1;
  }

  // Event-related functions
  void record(
      void** /*event*/,
      const Stream& /*stream*/,
      const DeviceIndex /*device_index*/,
      const EventFlag /*flag*/) const override {
    TORCH_CHECK(false, D, " backend doesn't support events.");
  }
  void block(void* /*event*/, const Stream& /*stream*/) const override {
    TORCH_CHECK(false, D, " backend doesn't support events.");
  }
  bool queryEvent(void* /*event*/) const override {
    TORCH_CHECK(false, D, " backend doesn't support events.");
  }
  void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/)
      const noexcept override {}

  // Stream-related functions
  bool queryStream(const Stream& /*stream*/) const override {
    return true;
  }
  void synchronizeStream(const Stream& /*stream*/) const override {
    // Don't wait for anything.
  }
};
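
// A usage sketch (hypothetical; actual registrations live in backend-specific
// translation units): a device type with no real device or stream state can
// reuse this template directly via the registration macro defined below, e.g.
//
//   C10_REGISTER_GUARD_IMPL(Meta, NoOpDeviceGuardImpl<DeviceType::Meta>);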

// The registry is NON-owning. Each stored pointer is std::atomic so
// that under all interleavings of registry calls the structure is
// race-free. This doesn't cost us anything on reads on x86. (An
// unsynchronized implementation probably is OK too, but I didn't want
// to prove that we never read from device_guard_impl_registry at the
// same time some registration is occurring. Shiver.)
//
// I'd like this registry to be valid even at program destruction time
// (in case someone uses a DeviceGuard in a destructor to do some cleanup
// in the CUDA API.) Since there are no direct accesses of the underlying
// owning objects which I can use to enforce initialization order (unlike
// in a Meyer singleton), it implies that you must *leak* objects when
// putting them in the registry. This is done by never freeing the
// registered objects: the registry stores raw pointers and nothing ever
// deletes them (see the note on ~DeviceGuardImplInterface above).
// NOLINTNEXTLINE(*c-arrays*)
extern C10_API std::atomic<const DeviceGuardImplInterface*>
    device_guard_impl_registry[static_cast<size_t>(
        DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];

// I can't conveniently use c10/util/Registry.h for the following reason:
// c10/util/Registry.h gives me a slow way of Create'ing an object of some
// interface from the registry, but no way of quickly accessing an already
// created object. I'll be banging on getDeviceGuardImpl every time we do a
// DeviceGuard, so I really don't want to be doing an unordered_map lookup.
// Better if the registration mechanism directly drops its implementation
// into device_guard_impl_registry.

class C10_API DeviceGuardImplRegistrar {
 public:
  DeviceGuardImplRegistrar(DeviceType, const DeviceGuardImplInterface*);
};

#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl)              \
  static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE( \
      g_##DevType)(::c10::DeviceType::DevType, new DeviceGuardImpl());
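
// For a real backend the pattern is the same: in the backend's .cpp file,
// define a final subclass and register it. A hypothetical sketch (MyDevice
// stands in for an actual DeviceType enumerator):
//
//   struct MyGuardImpl final : public c10::impl::DeviceGuardImplInterface {
//     // ... implement the virtual methods ...
//   };
//   C10_REGISTER_GUARD_IMPL(MyDevice, MyGuardImpl);
//
// The registrar runs during static initialization and drops the (leaked)
// instance into device_guard_impl_registry, so later lookups are a single
// atomic load rather than a map lookup.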

inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) {
  // Field accesses on two adjacent small integer fields (DeviceType and
  // DeviceIndex) have been miscompiled on NVCC. To work around this issue,
  // we apply a mask to the DeviceType. First, statically check that
  // DeviceType is 8-bit, so the 0xFF mask below is sufficient.
  // FB employees can see
  // https://fb.workplace.com/groups/llvm.gcc/permalink/4053565044692080/
  // for more details
  static_assert(sizeof(DeviceType) == 1, "DeviceType is not 8-bit");
  auto p = device_guard_impl_registry[static_cast<size_t>(type) & 0xFF].load();

  // This is typically the first place a device actually gets used when
  // devices are passed to factory functions, so give a nicer error message
  // in that case.
  TORCH_CHECK(p, "PyTorch is not linked with support for ", type, " devices");
  return p;
}

inline bool hasDeviceGuardImpl(DeviceType type) {
  return device_guard_impl_registry[static_cast<size_t>(type)].load();
}
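
// A usage sketch (roughly what the DeviceGuard machinery does; runOn and its
// body are illustrative, not an actual c10 API): fetch the impl once, then
// drive it directly.
//
//   void runOn(c10::Device d) {
//     const auto* impl = c10::impl::getDeviceGuardImpl(d.type());
//     c10::Device prev = impl->exchangeDevice(d);
//     // ... enqueue work on the current stream for d ...
//     impl->setDevice(prev); // restore (DeviceGuard does this in its dtor)
//   }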

} // namespace impl
} // namespace c10