// xref: /aosp_15_r20/external/pytorch/c10/core/impl/DeviceGuardImplInterface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>

// Just for C10_ANONYMOUS_VARIABLE
#include <c10/util/Registry.h>

#include <atomic>
13 namespace c10 {
14 
15 // Forward declaration
16 class DataPtr;
17 
18 /**
19  * Note [Flags defining the behavior of events]
20  *
21  * PYTORCH_DEFAULT and BACKEND_DEFAULT are valid for all backends. The
22  * BACKEND_DEFAULT is what a particular backend would select if no
23  * flags were given. PYTORCH_DEFAULT is the PyTorch's framework default
24  * choice for events on that backend, which may not be the same.
25  *
26  * The mapping of PYTORCH_DEFAULT and BACKEND_DEFAULT is done by each
27  * backend implementation.
28  */
29 enum class EventFlag {
30   // Disable timing
31   PYTORCH_DEFAULT,
32   // Enable timing
33   BACKEND_DEFAULT,
34   // FOR TESTING ONLY
35   INVALID
36 };
37 
38 namespace impl {
39 
40 /**
41  * DeviceGuardImplInterface represents the virtual interface which provides
42  * functionality to provide an RAII class for device and stream switching,
43  * via DeviceGuard.  Every distinct device type, e.g., CUDA and HIP, is
44  * expected to implement and register an implementation of this interface.
45  * All classes which inherit from DeviceGuardImplInterface should be declared
46  * 'final'.
47  *
48  * This class exists because we provide a unified interface for performing
49  * device guards via DeviceGuard, but we cannot assume that we have actually
50  * compiled against the, e.g., CUDA library, which actually implements
51  * this guard functionality.  In this case, a dynamic dispatch is required
52  * to cross the library boundary.
53  *
54  * If possible, you should directly use implementations of this interface;
55  * those uses will be devirtualized.
56  */
57 struct C10_API DeviceGuardImplInterface {
58   DeviceGuardImplInterface() = default;
59   DeviceGuardImplInterface(const DeviceGuardImplInterface&) = default;
60   DeviceGuardImplInterface& operator=(const DeviceGuardImplInterface&) =
61       default;
62   DeviceGuardImplInterface(DeviceGuardImplInterface&&) noexcept = default;
63   DeviceGuardImplInterface& operator=(DeviceGuardImplInterface&&) noexcept =
64       default;
65 
66   /**
67    * Return the type of device managed by this guard implementation.
68    */
69   virtual DeviceType type() const = 0;
70 
71   /**
72    * Set the current device to Device, and return the previous Device.
73    */
74   virtual Device exchangeDevice(Device) const = 0;
75   // NB: Implementations of exchangeDevice can be a bit boilerplatey.  You might
76   // consider replacing exchangeDevice with a non-virtual function with a baked
77   // in implementation; however, note that this will triple the number of
78   // virtual calls (when you implement exchangeDevice in a final subclass,
79   // the compiler gets to devirtualize everything; it won't do that if you don't
80   // define it in the subclass!)  A common way to solve this problem is to use
81   // some sort of CRTP; however, we can template DeviceGuardImplInterface since
82   // we really *do* need it to be virtual.  A little boilerplate seems easiest
83   // to explain.  (Another way around this problem is to provide inline
84   // functions that provide the default implementations, but this seems a little
85   // hard to explain.  In any case, we're only going to have on order of ten
86   // implementations of this anyway.)
87 
88   /**
89    * Get the current device.
90    */
91   virtual Device getDevice() const = 0;
92 
93   /**
94    * Set the current device to Device.
95    */
96   virtual void setDevice(Device) const = 0;
97 
98   /**
99    * Set the current device to Device, without checking for errors
100    * (so, e.g., this can be called from a destructor).
101    */
102   virtual void uncheckedSetDevice(Device) const noexcept = 0;
103 
104   /**
105    * Get the current stream for a given device.
106    */
107   virtual Stream getStream(Device) const noexcept = 0;
108 
109   /**
110    * Get the default stream for a given device.
111    */
getDefaultStreamDeviceGuardImplInterface112   virtual Stream getDefaultStream(Device) const {
113     TORCH_CHECK(false, "Backend doesn't support acquiring a default stream.")
114   }
115 
116   /**
117    * Get a stream from the global pool for a given device.
118    */
119   virtual Stream getStreamFromGlobalPool(Device, bool isHighPriority = false)
120       const {
121     (void)isHighPriority; // Suppress unused variable warning
122     TORCH_CHECK(false, "Backend doesn't support acquiring a stream from pool.")
123   }
124 
125   /**
126    * Return a new stream for a given device and priority. The stream will be
127    * copied and shared around, device backend should be able to correctly handle
128    * the lifetime of the stream.
129    */
130   virtual Stream getNewStream(Device, int priority = 0) const {
131     (void)priority;
132     TORCH_CHECK(false, "Backend doesn't support create a new Stream.")
133   }
134 
135   /**
136    * Set a stream to be the thread local current stream for its device.
137    * Return the previous stream for that device. You are NOT required
138    * to set the current device to match the device of this stream.
139    */
140   virtual Stream exchangeStream(Stream) const noexcept = 0;
141 
142   /**
143    * Destroys the given event.
144    */
destroyEventDeviceGuardImplInterface145   virtual void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/)
146       const noexcept {}
147 
148   /**
149    * Increments the event's version and enqueues a job with this version
150    * in the stream's work queue. When the stream process that job
151    * it notifies all streams waiting on / blocked by that version of the
152    * event to continue and marks that version as recorded.
153    * */
recordDeviceGuardImplInterface154   virtual void record(
155       void** /*event*/,
156       const Stream& /*stream*/,
157       const DeviceIndex /*device_index*/,
158       const c10::EventFlag /*flag*/) const {
159     TORCH_CHECK(false, "Backend doesn't support events.");
160   }
161 
162   /**
163    * Does nothing if the event has not been scheduled to be recorded.
164    * If the event was previously enqueued to be recorded, a command
165    * to wait for the version of the event that exists at the time of this call
166    * is inserted in the stream's work queue.
167    * When the stream reaches this command it will stop processing
168    * additional commands until that version of the event is marked as recorded.
169    */
blockDeviceGuardImplInterface170   virtual void block(void* /*event*/, const Stream& /*stream*/) const {
171     TORCH_CHECK(false, "Backend doesn't support events.");
172   }
173 
174   /**
175    * Returns true if (and only if)
176    *  (1) the event has never been scheduled to be recorded
177    *  (2) the current version is marked as recorded.
178    * Returns false otherwise.
179    */
queryEventDeviceGuardImplInterface180   virtual bool queryEvent(void* /*event*/) const {
181     TORCH_CHECK(false, "Backend doesn't support events.");
182   }
183 
184   /**
185    * Get the number of devices.  WARNING: This is REQUIRED to not raise
186    * an exception.  If there is some sort of problem, e.g., driver error,
187    * you should report that there are zero available devices.
188    */
189   virtual DeviceIndex deviceCount() const noexcept = 0;
190 
191   /**
192    * Return true if all the work previously enqueued on the stream for
193    * asynchronous execution has completed running on the device.
194    */
queryStreamDeviceGuardImplInterface195   virtual bool queryStream(const Stream& /*stream*/) const {
196     TORCH_CHECK(false, "Backend doesn't support querying streams.");
197   }
198 
199   /**
200    * Wait (by blocking the calling thread) until all the work previously
201    * enqueued on the stream has completed running on the device.
202    */
synchronizeStreamDeviceGuardImplInterface203   virtual void synchronizeStream(const Stream& /*stream*/) const {
204     TORCH_CHECK(false, "Backend doesn't support synchronizing streams.");
205   }
206 
207   /**
208    * Wait (by blocking the calling thread) until all the work previously
209    * recorded on the event has completed running on the device.
210    */
synchronizeEventDeviceGuardImplInterface211   virtual void synchronizeEvent(void* /*event*/) const {
212     TORCH_CHECK(false, "Backend doesn't support synchronizing events.");
213   }
214 
215   /**
216    * Ensure the caching allocator (if any) is aware that the given DataPtr is
217    * being used on the given stream, and that it should thus avoid recycling the
218    * DataPtr until all work on that stream is done.
219    */
recordDataPtrOnStreamDeviceGuardImplInterface220   virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const {
221   }
222 
223   /**
224    * Fetch the elapsed time between two recorded events.
225    */
elapsedTimeDeviceGuardImplInterface226   virtual double elapsedTime(
227       void* /*event1*/,
228       void* /*event2*/,
229       const DeviceIndex /*device_index*/) const {
230     TORCH_CHECK(false, "Backend doesn't support elapsedTime.");
231   }
232 
233   /**
234    * Intended use of this class is to leak the DeviceGuardImpl at program end.
235    * So you better not call the destructor, buster!
236    */
237   virtual ~DeviceGuardImplInterface() = default;
238 };
239 
240 // A no-op device guard impl that doesn't do anything interesting.  Useful
241 // for devices that don't actually have a concept of device index.  Prominent
242 // examples are CPU and Meta.
243 template <DeviceType D>
244 struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
245   NoOpDeviceGuardImpl() = default;
typefinal246   DeviceType type() const override {
247     return D;
248   }
exchangeDevicefinal249   Device exchangeDevice(Device) const override {
250     return Device(D, -1); // no-op
251   }
getDevicefinal252   Device getDevice() const override {
253     return Device(D, -1);
254   }
setDevicefinal255   void setDevice(Device) const override {
256     // no-op
257   }
uncheckedSetDevicefinal258   void uncheckedSetDevice(Device) const noexcept override {
259     // no-op
260   }
getStreamfinal261   Stream getStream(Device) const noexcept override {
262     // no-op
263     return Stream(Stream::DEFAULT, Device(D, -1));
264   }
265 
266   Stream getNewStream(Device, int priority = 0) const override {
267     // no-op
268     (void)priority;
269     return Stream(Stream::DEFAULT, Device(D, -1));
270   }
271 
272   // NB: These do NOT set the current device
exchangeStreamfinal273   Stream exchangeStream(Stream) const noexcept override {
274     // no-op
275     return Stream(Stream::DEFAULT, Device(D, -1));
276   }
deviceCountfinal277   DeviceIndex deviceCount() const noexcept override {
278     return 1;
279   }
280 
281   // Event-related functions
recordfinal282   void record(
283       void** /*event*/,
284       const Stream& /*stream*/,
285       const DeviceIndex /*device_index*/,
286       const EventFlag /*flag*/) const override {
287     TORCH_CHECK(false, D, " backend doesn't support events.");
288   }
blockfinal289   void block(void* /*event*/, const Stream& /*stream*/) const override {
290     TORCH_CHECK(false, D, " backend doesn't support events.")
291   }
queryEventfinal292   bool queryEvent(void* /*event*/) const override {
293     TORCH_CHECK(false, D, " backend doesn't support events.")
294   }
destroyEventfinal295   void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/)
296       const noexcept override {}
297 
298   // Stream-related functions
queryStreamfinal299   bool queryStream(const Stream& /*stream*/) const override {
300     return true;
301   }
synchronizeStreamfinal302   void synchronizeStream(const Stream& /*stream*/) const override {
303     // Don't wait for anything.
304   }
305 };
306 
307 // The registry is NON-owning.  Each stored pointer is std::atomic so
308 // that under all interleavings of registry calls the structure is
309 // race-free.  This doesn't cost us anything on reads in X86.  (An
310 // unsynchronized implementation probably is OK too, but I didn't want
311 // to prove that we never read from device_guard_impl_registry at the
312 // same time some registration is occurring.  Shiver.)
313 //
314 // I'd like this registry to be valid even at program destruction time
315 // (in case someone uses a DeviceGuard in a destructor to do some cleanup
316 // in the CUDA API.)  Since there are no direct accesses of the underlying
317 // owning objects which I can use to enforce initialization order (unlike
318 // in a Meyer singleton), it implies that you must *leak* objects when
319 // putting them in the registry.  This is done by deleting the destructor
320 // on DeviceGuardImplInterface.
321 // NOLINTNEXTLINE(*c-arrays*)
322 extern C10_API std::atomic<const DeviceGuardImplInterface*>
323     device_guard_impl_registry[static_cast<size_t>(
324         DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
325 
326 // I can't conveniently use c10/util/Registry.h for the following reason:
327 // c10/util/Registry.h gives me a slow way of Create'ing a object of some
328 // interface from the registry, but no way of quickly accessing an already
329 // created object.  I'll be banging on getDeviceGuardImpl every time we do a
330 // DeviceGuard, so I really don't want to be doing an unordered_map lookup.
331 // Better if the registration mechanism directly drops its implementation
332 // into device_guard_impl_registry.
333 
334 class C10_API DeviceGuardImplRegistrar {
335  public:
336   DeviceGuardImplRegistrar(DeviceType, const DeviceGuardImplInterface*);
337 };
338 
// Registers DeviceGuardImpl as the implementation for DeviceType::DevType by
// constructing a file-local DeviceGuardImplRegistrar at static-init time.
// The impl object is intentionally leaked (see the registry note above the
// extern declaration).  NB: the paste target is the DevType *argument*
// (g_##DevType), not the literal token DeviceType; either way
// C10_ANONYMOUS_VARIABLE keeps the variable name unique per translation unit.
#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl)              \
  static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE( \
      g_##DevType)(::c10::DeviceType::DevType, new DeviceGuardImpl());
342 
getDeviceGuardImpl(DeviceType type)343 inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) {
344   // Two adjacent int16_t fields DeviceType and DeviceIndex has field access
345   // miscompiled on NVCC. To workaround this issue, we apply a mask to the
346   // DeviceType. First check if the DeviceType is 16-bit.
347   // FB employees can see
348   //   https://fb.workplace.com/groups/llvm.gcc/permalink/4053565044692080/
349   // for more details
350   static_assert(sizeof(DeviceType) == 1, "DeviceType is not 8-bit");
351   auto p = device_guard_impl_registry[static_cast<size_t>(type) & 0xFF].load();
352 
353   // This seems to be the first place where you make use of a device
354   // when you pass devices to factory functions.  Give a nicer error
355   // message in this case.
356   TORCH_CHECK(p, "PyTorch is not linked with support for ", type, " devices");
357   return p;
358 }
359 
hasDeviceGuardImpl(DeviceType type)360 inline bool hasDeviceGuardImpl(DeviceType type) {
361   return device_guard_impl_registry[static_cast<size_t>(type)].load();
362 }
363 
364 } // namespace impl
365 } // namespace c10
366