xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/DispatchStub.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/DispatchStub.h>
3 
4 #include <c10/core/DeviceType.h>
5 #include <c10/util/Exception.h>
6 
7 #if !defined(__s390x__) && !defined(__powerpc__)
8 #include <cpuinfo.h>
9 #endif
10 #include <cstdlib>
11 #include <cstring>
12 
13 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
14 #include <sys/auxv.h>
15 #endif
16 
17 namespace at::native {
18 
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
// Reports whether the kernel advertises the s390x vector-enhancements
// facility (VXE) in the auxiliary vector's hardware-capability bits.
static inline bool cpu_has_vxe()
{
  const auto hwcap_bits = getauxval(AT_HWCAP);
  return (hwcap_bits & HWCAP_S390_VXE) != 0;
}
#endif
25 
compute_cpu_capability()26 static CPUCapability compute_cpu_capability() {
27   auto envar = std::getenv("ATEN_CPU_CAPABILITY");
28   if (envar) {
29 #if defined(HAVE_VSX_CPU_DEFINITION)
30     if (strcmp(envar, "vsx") == 0) {
31       return CPUCapability::VSX;
32     }
33 #elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
34     if (strcmp(envar, "zvector") == 0) {
35       return CPUCapability::ZVECTOR;
36     }
37 #else
38 #ifdef HAVE_AVX512_CPU_DEFINITION
39     if (strcmp(envar, "avx512") == 0) {
40       return CPUCapability::AVX512;
41     }
42 #endif
43 #ifdef HAVE_AVX2_CPU_DEFINITION
44     if (strcmp(envar, "avx2") == 0) {
45       return CPUCapability::AVX2;
46     }
47 #endif
48 #endif
49     if (strcmp(envar, "default") == 0) {
50       return CPUCapability::DEFAULT;
51     }
52     TORCH_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", envar);
53   }
54 
55 #if !defined(__powerpc__) && !defined(__s390x__)
56   if (cpuinfo_initialize()) {
57 #if defined(HAVE_AVX512_CPU_DEFINITION)
58     // GCC supports some AVX512 intrinsics such as _mm512_set_epi16 only in
59     // versions 9 & beyond. So, we want to ensure that only releases built with
60     // supported compilers on supported hardware return CPU Capability AVX512,
61     // if it's supported on the hardware PyTorch is running on.
62     if (cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() &&  \
63         cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_fma3()) {
64       return CPUCapability::AVX512;
65     }
66 #endif
67 #ifdef HAVE_AVX2_CPU_DEFINITION
68     if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) {
69       return CPUCapability::AVX2;
70     }
71 #endif
72   }
73 #endif
74 
75 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
76   // vxe is needed for fp32 vector instructions
77   if (cpu_has_vxe()) {
78     return CPUCapability::ZVECTOR;
79   }
80 #endif
81 
82 #ifdef HAVE_VSX_CPU_DEFINITION
83   return CPUCapability::VSX;
84 #else
85   return CPUCapability::DEFAULT;
86 #endif
87 }
88 
get_cpu_capability()89 CPUCapability get_cpu_capability() {
90   static CPUCapability capability = compute_cpu_capability();
91   return capability;
92 }
93 
try_get_call_ptr(const DeviceType device_type,void * DEFAULT,void * AVX512,void * AVX2,void * VSX,void * ZVECTOR)94 DispatchResult DispatchStubImpl::try_get_call_ptr(
95   const DeviceType device_type
96   , void *DEFAULT
97 #ifdef HAVE_AVX512_CPU_DEFINITION
98   , void *AVX512
99 #endif
100 #ifdef HAVE_AVX2_CPU_DEFINITION
101   , void *AVX2
102 #endif
103 #ifdef HAVE_VSX_CPU_DEFINITION
104   , void *VSX
105 #endif
106 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
107   , void *ZVECTOR
108 #endif
109 ) {
110   constexpr auto supported_devices = c10::array_of<c10::DeviceType>(
111         c10::DeviceType::CPU,
112         c10::DeviceType::CUDA,
113         c10::DeviceType::HIP,
114         c10::DeviceType::MPS,
115         c10::DeviceType::MTIA,
116         c10::DeviceType::XPU,
117         c10::DeviceType::PrivateUse1
118     );
119     // Check if the device type is supported.
120     if (std::find(supported_devices.begin(), supported_devices.end(), device_type) == supported_devices.end()) {
121         return ErrorType::DeviceNotSupported;
122     }
123   switch (device_type) {
124     case DeviceType::CPU: {
125       // Use memory_order_relaxed here since even if two threads race,
126       // they will still compute the same value for cpu_dispatch_ptr.
127       auto fptr = cpu_dispatch_ptr.load(std::memory_order_relaxed);
128       if (!fptr) {
129         auto result = try_choose_cpu_impl(
130           DEFAULT
131 #ifdef HAVE_AVX512_CPU_DEFINITION
132           , AVX512
133 #endif
134 #ifdef HAVE_AVX2_CPU_DEFINITION
135           , AVX2
136 #endif
137 #ifdef HAVE_VSX_CPU_DEFINITION
138           , VSX
139 #endif
140 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
141           , ZVECTOR
142 #endif
143         );
144         if (!std::holds_alternative<ErrorType>(result)) {
145           cpu_dispatch_ptr.store(fptr, std::memory_order_relaxed);
146         }
147       return result;
148       }
149       return DispatchResult(fptr);
150     }
151 
152     case DeviceType::CUDA:
153       return cuda_dispatch_ptr != nullptr ? DispatchResult(cuda_dispatch_ptr) : ErrorType::MissingDeviceKernel;
154 
155     case DeviceType::HIP:
156       return hip_dispatch_ptr != nullptr ? DispatchResult(hip_dispatch_ptr) : ErrorType::MissingDeviceKernel;
157 
158 #if defined(USE_MPS)
159     case DeviceType::MPS:
160       return mps_dispatch_ptr != nullptr ? DispatchResult(mps_dispatch_ptr) : ErrorType::MissingDeviceKernel;
161 #endif
162     case DeviceType::MTIA:
163       return mtia_dispatch_ptr != nullptr ? DispatchResult(mtia_dispatch_ptr) : ErrorType::MissingDeviceKernel;
164 
165 #if defined(USE_XPU)
166     case DeviceType::XPU:
167       return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel;
168 #endif
169 
170     case DeviceType::PrivateUse1:
171       return privateuse1_dispatch_ptr != nullptr ? DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel;
172 
173     default:
174       TORCH_INTERNAL_ASSERT(false, "An unexpected device type was provided ", device_type);
175       return ErrorType::DeviceNotSupported;
176   }
177 }
178 
get_call_ptr(const DeviceType device_type,void * DEFAULT,void * AVX512,void * AVX2,void * VSX,void * ZVECTOR)179 void* DispatchStubImpl::get_call_ptr(
180   const DeviceType device_type
181   , void *DEFAULT
182 #ifdef HAVE_AVX512_CPU_DEFINITION
183   , void *AVX512
184 #endif
185 #ifdef HAVE_AVX2_CPU_DEFINITION
186   , void *AVX2
187 #endif
188 #ifdef HAVE_VSX_CPU_DEFINITION
189   , void *VSX
190 #endif
191 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
192   , void *ZVECTOR
193 #endif
194 ) {
195 
196   auto result = try_get_call_ptr(
197       device_type,
198       DEFAULT
199 #ifdef HAVE_AVX512_CPU_DEFINITION
200       ,
201       AVX512
202 #endif
203 #ifdef HAVE_AVX2_CPU_DEFINITION
204       ,
205       AVX2
206 #endif
207 #ifdef HAVE_VSX_CPU_DEFINITION
208       ,
209       VSX
210 #endif
211 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
212       ,
213       ZVECTOR
214 #endif
215   );
216   if (std::holds_alternative<ErrorType>(result)) {
217     auto error = std::get<ErrorType>(result);
218     switch (error) {
219       case ErrorType::MissingDeviceKernel:
220         TORCH_INTERNAL_ASSERT(
221             false, "DispatchStub: missing kernel for ", device_type);
222         return nullptr;
223       case ErrorType::DeviceNotSupported:
224         AT_ERROR("DispatchStub: unsupported device type", device_type);
225     }
226   }
227 
228   void* fptr = std::get<void*>(result);
229   return fptr;
230 }
231 
try_choose_cpu_impl(void * DEFAULT,void * AVX512,void * AVX2,void * VSX,void * ZVECTOR)232 DispatchResult DispatchStubImpl::try_choose_cpu_impl(
233     void *DEFAULT
234 #ifdef HAVE_AVX512_CPU_DEFINITION
235     , void *AVX512
236 #endif
237 #ifdef HAVE_AVX2_CPU_DEFINITION
238     , void *AVX2
239 #endif
240 #ifdef HAVE_VSX_CPU_DEFINITION
241     , void *VSX
242 #endif
243 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
244     , void *ZVECTOR
245 #endif
246   ){
247 
248   auto capability = static_cast<int>(get_cpu_capability());
249   (void)capability;
250 #ifdef HAVE_AVX512_CPU_DEFINITION
251   if (capability >= static_cast<int>(CPUCapability::AVX512)) {
252     // Quantization kernels have also been disabled on Windows
253     // for AVX512 because some of their tests are flaky on Windows.
254     // Ideally, we should have AVX512 kernels for all kernels.
255     if (C10_UNLIKELY(!AVX512)) {
256       // dispatch to AVX2, since the AVX512 kernel is missing
257       return AVX2 != nullptr ? DispatchResult(AVX2) : ErrorType::MissingDeviceKernel;
258     } else {
259       return DispatchResult(AVX512);
260     }
261   }
262 #endif
263 #ifdef HAVE_AVX2_CPU_DEFINITION
264   if (capability >= static_cast<int>(CPUCapability::AVX2)) {
265     return AVX2 != nullptr ? DispatchResult(AVX2) : ErrorType::MissingDeviceKernel;
266   }
267 #endif
268 #ifdef HAVE_VSX_CPU_DEFINITION
269   if (capability >= static_cast<int>(CPUCapability::VSX)) {
270     return VSX != nullptr ? DispatchResult(VSX) : ErrorType::MissingDeviceKernel;
271   }
272 #endif
273 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
274   if (capability >= static_cast<int>(CPUCapability::ZVECTOR)) {
275     return ZVECTOR != nullptr ? DispatchResult(ZVECTOR) : ErrorType::MissingDeviceKernel;
276   }
277 #endif
278   return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
279 }
280 
choose_cpu_impl(void * DEFAULT,void * AVX512,void * AVX2,void * VSX,void * ZVECTOR)281 void* DispatchStubImpl::choose_cpu_impl(
282   void *DEFAULT
283 #ifdef HAVE_AVX512_CPU_DEFINITION
284   , void *AVX512
285 #endif
286 #ifdef HAVE_AVX2_CPU_DEFINITION
287   , void *AVX2
288 #endif
289 #ifdef HAVE_VSX_CPU_DEFINITION
290   , void *VSX
291 #endif
292 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
293   , void *ZVECTOR
294 #endif
295 ) {
296   auto capability = static_cast<int>(get_cpu_capability());
297   (void)capability;
298 #ifdef HAVE_AVX512_CPU_DEFINITION
299   if (capability >= static_cast<int>(CPUCapability::AVX512)) {
300     // Quantization kernels have also been disabled on Windows
301     // for AVX512 because some of their tests are flaky on Windows.
302     // Ideally, we should have AVX512 kernels for all kernels.
303     if (C10_UNLIKELY(!AVX512)) {
304       // dispatch to AVX2, since the AVX512 kernel is missing
305       TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel");
306       return AVX2;
307     } else {
308       return AVX512;
309     }
310   }
311 #endif
312 #ifdef HAVE_AVX2_CPU_DEFINITION
313   if (capability >= static_cast<int>(CPUCapability::AVX2)) {
314     TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel");
315     return AVX2;
316   }
317 #endif
318 #ifdef HAVE_VSX_CPU_DEFINITION
319   if (capability >= static_cast<int>(CPUCapability::VSX)) {
320     TORCH_INTERNAL_ASSERT(VSX, "DispatchStub: missing VSX kernel");
321     return VSX;
322   }
323 #endif
324 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
325   if (capability >= static_cast<int>(CPUCapability::ZVECTOR)) {
326     TORCH_INTERNAL_ASSERT(ZVECTOR, "DispatchStub: missing ZVECTOR kernel");
327     return ZVECTOR;
328   }
329 #endif
330   TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
331   return DEFAULT;
332 }
333 
334 }  // namespace at::native
335