#pragma once

#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Array.h>

#include <atomic>
#include <utility>
#include <variant>

// Implements instruction-set-specific function dispatch.
//
// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
// compiled multiple times with different compiler flags (e.g. -mavx2). A
// DispatchStub contains a table of function pointers for a kernel. At runtime,
// the fastest available kernel is chosen based on the features reported by
// cpuinfo.
//
// Example:
//
// In native/MyKernel.h:
//   using fn_type = void(*)(const Tensor& x);
//   DECLARE_DISPATCH(fn_type, stub);
//
// In native/MyKernel.cpp:
//   DEFINE_DISPATCH(stub);
//
// In native/cpu/MyKernel.cpp:
//   namespace {
//     // use anonymous namespace so that different cpu versions won't conflict
//     void kernel(const Tensor& x) { ... }
//   }
//   REGISTER_DISPATCH(stub, &kernel);
//
// To call:
//   stub(kCPU, tensor);
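//
// To register a kernel for another backend, use the matching device macro
// from the appropriate translation unit. A minimal sketch (file and kernel
// names hypothetical; the macro is defined later in this header):
//
// In native/cuda/MyKernel.cu:
//   void kernel_cuda(const Tensor& x) { ... }
//   REGISTER_CUDA_DISPATCH(stub, &kernel_cuda);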
//
// TODO: CPU instruction set selection should be folded into whatever
// the main dispatch mechanism is.
//
// Supported device types for registration:
//   - CPU: Central Processing Unit
//   - CUDA: NVIDIA GPUs
//   - HIP: AMD GPUs
//   - MPS: Apple Silicon GPUs (Metal Performance Shaders)
//   - MTIA: Meta Training and Inference Devices
//   - XPU: Intel GPUs
//   - PrivateUse1: Reserved for private/custom device types
//
// If you want to update the list of supported devices, add a new dispatch_ptr
// member in DispatchStubImpl below and update the get_call_ptr switch.
// You will also need to update the inlined list in `is_device_supported`.
//
//
// ignore warnings about DispatchStub::DEFAULT, AVX2, AVX512 defined elsewhere
C10_CLANG_DIAGNOSTIC_PUSH()
C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")

namespace at::native {

enum class CPUCapability {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  NUM_OPTIONS
};

// Enum for error types
enum class ErrorType {
  MissingDeviceKernel,
  DeviceNotSupported
};

// Alias for the return type using std::variant
using DispatchResult = std::variant<void*, ErrorType>;
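//
// A minimal sketch of consuming a DispatchResult (names hypothetical):
//   DispatchResult result = impl.try_get_call_ptr(device_type, ...);
//   if (std::holds_alternative<ErrorType>(result)) {
//     // missing kernel or unsupported device
//   } else {
//     void* fn = std::get<void*>(result);
//   }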

CPUCapability get_cpu_capability();
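//
// For example (a sketch; kernel code normally branches on the CPU_CAPABILITY
// macro at compile time rather than calling this at runtime):
//   if (get_cpu_capability() == CPUCapability::AVX2) { /* AVX2 is available */ }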

template <typename FnPtr, typename T>
struct DispatchStub;

/**
 * The sole purpose of this class is to outline methods that don't need to be
 * specialized or otherwise inlined and duplicated (by the compiler due to
 * template expansion), since it causes size bloat if there are a significant
 * number of specializations of the DispatchStub<> class.
 */
struct TORCH_API DispatchStubImpl {

  // DispatchStubImpl::try_get_call_ptr() returns the call pointer for a given
  // device type. If no call pointer is found, it returns an ErrorType. Unlike
  // get_call_ptr(), try_get_call_ptr() reports failure through the returned
  // ErrorType rather than raising an exception.
  DispatchResult try_get_call_ptr(
    c10::DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
      , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , void *ZVECTOR
#endif
  );

  // Analogous to choose_cpu_impl(), but returns an ErrorType instead of
  // raising an exception.
  DispatchResult try_choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );


  void* get_call_ptr(
    c10::DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
      , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , void *ZVECTOR
#endif
  );

  /**
   * The actual CPU kernel is chosen by DispatchStubImpl::choose_cpu_impl(),
   * in decreasing order of preference, when DispatchStubImpl::get_call_ptr()
   * finds no cached pointer in cpu_dispatch_ptr.
   */
  void* choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Fixing dispatch error in Windows debug builds.
  // See https://github.com/pytorch/pytorch/issues/22681 for more details.
  #if defined(_MSC_VER) && defined(_DEBUG)
    std::atomic<void*> cpu_dispatch_ptr;
    void* cuda_dispatch_ptr;
    void* hip_dispatch_ptr;
    void* mps_dispatch_ptr;
    void* mtia_dispatch_ptr;
  #if defined(USE_XPU)
    void* xpu_dispatch_ptr;
  #endif
    void* privateuse1_dispatch_ptr;
  #else
    std::atomic<void*> cpu_dispatch_ptr{nullptr};
    void* cuda_dispatch_ptr = nullptr;
    void* hip_dispatch_ptr = nullptr;
    void* mps_dispatch_ptr = nullptr;
    void* mtia_dispatch_ptr = nullptr;
  #if defined(USE_XPU)
    void* xpu_dispatch_ptr = nullptr;
  #endif
    void* privateuse1_dispatch_ptr = nullptr;
  #endif
};

template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T> {
  using FnPtr = rT (*)(Args...);

  DispatchStub() = default;
  DispatchStub(const DispatchStub&) = delete;
  DispatchStub& operator=(const DispatchStub&) = delete;

private:
  FnPtr get_call_ptr(const c10::DeviceType device_type) {
    return reinterpret_cast<FnPtr>(
      impl.get_call_ptr(device_type
      , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , reinterpret_cast<void*>(ZVECTOR)
#endif
      )
    );
  }

public:
  template <typename... ArgTypes>
  rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
    FnPtr call_ptr = get_call_ptr(device_type);
    return (*call_ptr)(std::forward<ArgTypes>(args)...);
  }

  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  #if defined(USE_XPU)
  void set_xpu_dispatch_ptr(FnPtr fn_ptr) {
    impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
  #endif

  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
    impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_mtia_dispatch_ptr(FnPtr fn_ptr) {
    impl.mtia_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
    impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  // Returns true if the dispatcher has a kernel registered for this device
  // type.
  bool is_device_supported(const c10::DeviceType device_type) {
    auto result = impl.try_get_call_ptr(device_type
      , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , reinterpret_cast<void*>(ZVECTOR)
#endif
      );
    return !std::holds_alternative<ErrorType>(result);
  }
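  // Hypothetical usage (sketch): check before dispatching so a caller can
  // fall back when no kernel is registered for a device:
  //   if (!stub.is_device_supported(c10::DeviceType::CUDA)) { /* fall back */ }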

  static TORCH_API FnPtr DEFAULT;
#ifdef HAVE_AVX512_CPU_DEFINITION
  static TORCH_API FnPtr AVX512;
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
  static TORCH_API FnPtr AVX2;
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
  static TORCH_API FnPtr VSX;
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
  static TORCH_API FnPtr ZVECTOR;
#endif
private:
  DispatchStubImpl impl;
};

namespace {
template <typename DispatchStub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_cuda_dispatch_ptr(value);
  }
};

template <typename DispatchStub>
struct RegisterXPUDispatch {
  RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_xpu_dispatch_ptr(value);
  }
};

template <typename DispatchStub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_mps_dispatch_ptr(value);
  }
};

template <typename DispatchStub>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    // TODO: make this point at hip_dispatch_ptr
    stub.set_cuda_dispatch_ptr(value);
  }
};

template <typename DispatchStub>
struct RegisterMTIADispatch {
  RegisterMTIADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_mtia_dispatch_ptr(value);
  }
};

template <typename DispatchStub>
struct RegisterPRIVATEUSE1Dispatch {
  RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_privateuse1_dispatch_ptr(value);
  }
};

} // anonymous namespace
// The compiler will complain if you put things like std::tuple<Tensor, Tensor>
// in the `fn` argument of DECLARE_DISPATCH: the comma inside the template
// argument list splits the macro argument. Some possible workarounds, e.g.,
// adding parentheses and using a helper struct to get rid of the parentheses,
// do not work with MSVC. So use a `using` type alias if you need to pass in
// such an `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
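//
// A hypothetical sketch of that alias workaround:
//   using my_fn_type = std::tuple<Tensor, Tensor> (*)(const Tensor&, const Tensor&);
//   DECLARE_DISPATCH(my_fn_type, my_stub);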
#define DECLARE_DISPATCH(fn, name)                                                         \
  struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> {   \
    name##_DECLARE_DISPATCH_type() = default;                                              \
    name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete;            \
    name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
  };                                                                                       \
  extern TORCH_API struct name##_DECLARE_DISPATCH_type name;

#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name

#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub<name##_DECLARE_DISPATCH_type::FnPtr, struct name##_DECLARE_DISPATCH_type>::arch = fn;

#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif

#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif

#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif

#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif

// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn)                                    \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn)                                    \
  REGISTER_AVX512_DISPATCH(name, fn)                                           \
  REGISTER_AVX2_DISPATCH(name, fn)                                             \
  REGISTER_VSX_DISPATCH(name, fn)                                              \
  REGISTER_ZVECTOR_DISPATCH(name, fn)
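
// Hypothetical usage (sketch): register one generic kernel for every CPU arch
// when per-arch recompilation brings no benefit:
//   REGISTER_ALL_CPU_DISPATCH(my_stub, &my_generic_kernel);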

#define REGISTER_NO_CPU_DISPATCH(name)                                         \
  REGISTER_ALL_CPU_DISPATCH(name, nullptr)

#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_XPU_DISPATCH(name, fn) \
  static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_MTIA_DISPATCH(name, fn) \
  static RegisterMTIADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
  static RegisterPRIVATEUSE1Dispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from an 'mm' file in order to dispatch an MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// Under CPU_CAPABILITY_AVX512, REGISTER_DISPATCH registers nullptr, so the
// AVX512 variant is not dispatched (the nullptr-yielding ternary below still
// references fn, which keeps the kernel symbol from being flagged as unused);
// all other arch builds register fn normally. Use ALSO_REGISTER_AVX512_DISPATCH
// when the AVX512 kernel should actually be registered.
#ifdef CPU_CAPABILITY_AVX512
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
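//
// A hypothetical sketch: in a per-arch CPU kernel file, when the AVX512
// variant should really be registered and dispatched:
//   ALSO_REGISTER_AVX512_DISPATCH(stub, &kernel);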
#endif
} // namespace at::native

C10_CLANG_DIAGNOSTIC_POP()