1 #pragma once 2 3 #include <c10/core/DeviceType.h> 4 #include <c10/macros/Macros.h> 5 #include <c10/util/Array.h> 6 7 #include <atomic> 8 #include <utility> 9 #include <variant> 10 11 // Implements instruction set specific function dispatch. 12 // 13 // Kernels that may make use of specialized instruction sets (e.g. AVX2) are 14 // compiled multiple times with different compiler flags (e.g. -mavx2). A 15 // DispatchStub contains a table of function pointers for a kernel. At runtime, 16 // the fastest available kernel is chosen based on the features reported by 17 // cpuinfo. 18 // 19 // Example: 20 // 21 // In native/MyKernel.h: 22 // using fn_type = void(*)(const Tensor& x); 23 // DECLARE_DISPATCH(fn_type, stub); 24 // 25 // In native/MyKernel.cpp 26 // DEFINE_DISPATCH(stub); 27 // 28 // In native/cpu/MyKernel.cpp: 29 // namespace { 30 // // use anonymous namespace so that different cpu versions won't conflict 31 // void kernel(const Tensor& x) { ... } 32 // } 33 // REGISTER_DISPATCH(stub, &kernel); 34 // 35 // To call: 36 // stub(kCPU, tensor); 37 // 38 // TODO: CPU instruction set selection should be folded into whatever 39 // the main dispatch mechanism is. 40 // 41 // Supported device types for registration: 42 // - CPU: Central Processing Unit 43 // - CUDA: NVIDIA GPUs 44 // - HIP: AMD GPUs 45 // - MPS: Apple Silicon GPUs (Metal Performance Shaders) 46 // - MTIA: Meta Training and Inference Devices 47 // - XPU: Intel GPUs 48 // - PrivateUse1: Reserved for private/custom device types 49 // 50 // If you want to update the list of supported devices, add a new dispatch_ptr 51 // member in DispatchStubImpl.h and update the get_call_ptr switch. 
// As well you will need to update the inlined list in `is_device_supported`
//
//
// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
C10_CLANG_DIAGNOSTIC_PUSH()
C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")

namespace at::native {

// CPU instruction-set levels, in increasing order of preference. At most one
// SIMD family exists per build: VSX (PowerPC) and ZVECTOR (s390x) are mutually
// exclusive with the x86 AVX2/AVX512 levels, as the #if chain below shows.
enum class CPUCapability {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  NUM_OPTIONS  // sentinel: count of capability levels, not a real capability
};

// Enum for error types returned by the non-throwing lookup paths below.
enum class ErrorType {
  MissingDeviceKernel,   // no kernel registered for the requested device
  DeviceNotSupported     // device type has no dispatch slot at all
};

// Alias for the return type using std::variant: either the kernel's function
// pointer (success) or an ErrorType (failure).
using DispatchResult = std::variant<void*, ErrorType>;

// Returns the best CPUCapability available at runtime (see file header:
// chosen based on the features reported by cpuinfo).
CPUCapability get_cpu_capability();

template <typename FnPtr, typename T>
struct DispatchStub;

/**
 * The sole purpose of this class is to outline methods that don't need to be
 * specialized or otherwise inlined and duplicated (by the compiler due to
 * template expansion), since it causes size bloat if there are a significant
 * number of specialization of the DispatchStub<> class.
 */
struct TORCH_API DispatchStubImpl {

  // The DispatchStubImpl::try_get_call_ptr() method is used to get the call
  // pointer for a given device type. If the call pointer is not found,
  // DispatchStubImpl::try_get_call_ptr() returns an ErrorType.
  // The main difference between try_get_call_ptr() and get_call_ptr() is that
  // try_get_call_ptr() will return the ErrorType and not raise an exception.
  //
  // The trailing void* parameters are the CPU kernel candidates for each
  // instruction set compiled into this build (one per HAVE_*_CPU_DEFINITION).
  DispatchResult try_get_call_ptr(
    c10::DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Analogous to choose_cpu_impl(), but it will return the ErrorType and not
  // raise an exception.
  DispatchResult try_choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Throwing counterpart of try_get_call_ptr(): returns the kernel pointer
  // for device_type, raising on failure instead of returning an ErrorType.
  void* get_call_ptr(
    c10::DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  /**
   * The CPU Dispatch actual method is chosen in decreasing order of preference by
   * DispatchStubImpl::choose_cpu_impl() in case none is found by
   * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
   */
  void* choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Fixing dispatch error in Windows debug builds.
  // See https://github.com/pytorch/pytorch/issues/22681 for more details.
  // NOTE(review): the MSVC debug branch deliberately omits the in-class
  // initializers used in the #else branch — confirm against the linked issue
  // before "fixing" this apparent inconsistency.
#if defined(_MSC_VER) && defined(_DEBUG)
  std::atomic<void*> cpu_dispatch_ptr;
  void* cuda_dispatch_ptr;
  void* hip_dispatch_ptr;
  void* mps_dispatch_ptr;
  void* mtia_dispatch_ptr;
#if defined(USE_XPU)
  void* xpu_dispatch_ptr;
#endif
  void* privateuse1_dispatch_ptr;
#else
  // One cached kernel slot per supported device type. The CPU slot is atomic
  // because it is filled by choose_cpu_impl() (see comment above) and may be
  // read concurrently; the other slots are written by static registrars.
  std::atomic<void*> cpu_dispatch_ptr{nullptr};
  void* cuda_dispatch_ptr = nullptr;
  void* hip_dispatch_ptr = nullptr;
  void* mps_dispatch_ptr = nullptr;
  void* mtia_dispatch_ptr = nullptr;
#if defined(USE_XPU)
  void* xpu_dispatch_ptr = nullptr;
#endif
  void* privateuse1_dispatch_ptr = nullptr;
#endif
};

// Partial specialization of DispatchStub for function-pointer kernel types.
// `T` is the concrete stub type generated by DECLARE_DISPATCH; it makes the
// static per-arch members below unique to each declared stub.
template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T> {
  using FnPtr = rT (*) (Args...);

  DispatchStub() = default;
  DispatchStub(const DispatchStub&) = delete;
  DispatchStub& operator=(const DispatchStub&) = delete;

private:
  // Resolves the kernel for device_type, forwarding every CPU candidate
  // compiled into this build. Errors are raised inside impl.get_call_ptr().
  FnPtr get_call_ptr(const c10::DeviceType device_type) {
    return reinterpret_cast<FnPtr>(
      impl.get_call_ptr(device_type
      , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , reinterpret_cast<void*>(ZVECTOR)
#endif
      )
    );
  }

public:
  // Invokes the registered kernel for device_type, e.g. stub(kCPU, tensor).
  template <typename... ArgTypes>
  rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
    FnPtr call_ptr = get_call_ptr(device_type);
    return (*call_ptr)(std::forward<ArgTypes>(args)...);
  }

  // Setters used by the REGISTER_*_DISPATCH registrars: each stores the
  // kernel pointer into the matching per-device slot in impl.
  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

#if defined(USE_XPU)
  void set_xpu_dispatch_ptr(FnPtr fn_ptr){
    impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
#endif

  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
    impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_mtia_dispatch_ptr(FnPtr fn_ptr) {
    impl.mtia_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
    impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  // Returns true if the dispatcher has a kernel registered for this device
  // type. Unlike operator(), this never raises: it uses the non-throwing
  // try_get_call_ptr() and maps any ErrorType to false.
  bool is_device_supported(const c10::DeviceType device_type) {
    auto result = impl.try_get_call_ptr(device_type
    , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
    , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , reinterpret_cast<void*>(ZVECTOR)
#endif
    );
    if (std::holds_alternative<ErrorType>(result)){
      return false;
    }
    return true;
  };

  // Per-arch CPU kernel pointers. Each REGISTER_ARCH_DISPATCH expansion
  // defines one of these via explicit specialization in an arch-specific TU
  // (hence the -Wundefined-var-template suppression at the top of the file).
  static TORCH_API FnPtr DEFAULT;
#ifdef HAVE_AVX512_CPU_DEFINITION
  static TORCH_API FnPtr AVX512;
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
  static TORCH_API FnPtr AVX2;
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
  static TORCH_API FnPtr VSX;
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
  static TORCH_API FnPtr ZVECTOR;
#endif
private:
  DispatchStubImpl impl;
};

namespace {
// Registrar helpers: constructing one of these stores `fn` into the matching
// per-device slot of stub `s`. The REGISTER_*_DISPATCH macros below
// instantiate them as static objects, so registration happens during static
// initialization of the translation unit that defines the kernel.
template <typename DispatchStub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(DispatchStub& s, typename DispatchStub::FnPtr fn) {
    s.set_cuda_dispatch_ptr(fn);
  }
};

template <typename DispatchStub>
struct RegisterXPUDispatch {
  RegisterXPUDispatch(DispatchStub& s, typename DispatchStub::FnPtr fn) {
    s.set_xpu_dispatch_ptr(fn);
  }
};

template <typename DispatchStub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(DispatchStub& s, typename DispatchStub::FnPtr fn) {
    s.set_mps_dispatch_ptr(fn);
  }
};

template <typename DispatchStub>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(DispatchStub& s, typename DispatchStub::FnPtr fn) {
    // TODO: make this point at hip_dispatch_ptr
    s.set_cuda_dispatch_ptr(fn);
  }
};

template <typename DispatchStub>
struct RegisterMTIADispatch {
  RegisterMTIADispatch(DispatchStub& s, typename DispatchStub::FnPtr fn) {
    s.set_mtia_dispatch_ptr(fn);
  }
};

template <typename DispatchStub>
struct RegisterPRIVATEUSE1Dispatch {
  RegisterPRIVATEUSE1Dispatch(DispatchStub& s, typename DispatchStub::FnPtr fn) {
    s.set_privateuse1_dispatch_ptr(fn);
  }
};

} // anonymous namespace
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
// adding parentheses and using helper struct to get rid of the parentheses, do
// not work with MSVC. So do a `using`-declaration if you need to pass in such
// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
// Declares a stub named `name` for kernels with function-pointer type `fn`.
// Each stub gets its own struct type so the static per-arch kernel pointers
// inherited from DispatchStub<> are distinct for every declared stub.
#define DECLARE_DISPATCH(fn, name) \
  struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> { \
    name##_DECLARE_DISPATCH_type() = default; \
    name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete; \
    name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
  }; \
  extern TORCH_API struct name##_DECLARE_DISPATCH_type name;

// Defines the stub object declared by DECLARE_DISPATCH (use in exactly one .cpp).
#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name

// Defines, via explicit specialization, the static kernel pointer for one CPU
// arch slot (`arch` is DEFAULT/AVX512/AVX2/VSX/ZVECTOR) of stub `name`.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub<name##_DECLARE_DISPATCH_type::FnPtr, struct name##_DECLARE_DISPATCH_type>::arch = fn;

// Per-arch registration helpers: each expands to a real registration only
// when the corresponding instruction set is part of this build; otherwise it
// expands to nothing.
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif

#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif

#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif

#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif

// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn)     \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn)     \
  REGISTER_AVX512_DISPATCH(name, fn)            \
  REGISTER_AVX2_DISPATCH(name, fn)              \
  REGISTER_VSX_DISPATCH(name, fn)               \
  REGISTER_ZVECTOR_DISPATCH(name, fn)

// Registers nullptr for every CPU arch slot: the stub then fails at dispatch
// time (not link time) if called with kCPU.
#define REGISTER_NO_CPU_DISPATCH(name) \
  REGISTER_ALL_CPU_DISPATCH(name, nullptr)

// Device registrations: each expands to a static registrar object (see the
// Register*Dispatch helpers above) whose constructor stores `fn` in the
// stub's per-device slot during static initialization.
#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_XPU_DISPATCH(name, fn) \
  static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_MTIA_DISPATCH(name, fn) \
  static RegisterMTIADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
  static RegisterPRIVATEUSE1Dispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
// Note: if none of the branches below match (e.g. a plain .cpp compiled
// without CPU_CAPABILITY), REGISTER_DISPATCH is intentionally left undefined.
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
#ifdef CPU_CAPABILITY_AVX512
// Both branches of the conditional register nullptr; the `(void*)(fn)` test
// appears intended only to keep `fn` referenced (avoiding unused warnings).
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
} // namespace at::native

C10_CLANG_DIAGNOSTIC_POP()