1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/DispatchStub.h>
3
4 #include <c10/core/DeviceType.h>
5 #include <c10/util/Exception.h>
6
7 #if !defined(__s390x__) && !defined(__powerpc__)
8 #include <cpuinfo.h>
9 #endif
10 #include <cstdlib>
11 #include <cstring>
12
13 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
14 #include <sys/auxv.h>
15 #endif
16
17 namespace at::native {
18
19 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
cpu_has_vxe()20 static inline bool cpu_has_vxe()
21 {
22 return (getauxval(AT_HWCAP) & HWCAP_S390_VXE);
23 }
24 #endif
25
// Determines the best CPU vector-instruction capability for this process.
// Order of precedence:
//   1. An explicit ATEN_CPU_CAPABILITY environment-variable override
//      (only values compiled into this build are honored).
//   2. Runtime detection via cpuinfo (x86) or getauxval (s390x).
//   3. The compile-time baseline (VSX on ppc64le builds, else DEFAULT).
// Called once; the result is cached by get_cpu_capability().
static CPUCapability compute_cpu_capability() {
  auto envar = std::getenv("ATEN_CPU_CAPABILITY");
  if (envar) {
    // The accepted override strings depend on which ISA kernels were
    // compiled in: VSX and ZVECTOR builds are mutually exclusive with the
    // x86 (AVX512/AVX2) builds, hence the #if/#elif/#else ladder.
#if defined(HAVE_VSX_CPU_DEFINITION)
    if (strcmp(envar, "vsx") == 0) {
      return CPUCapability::VSX;
    }
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
    if (strcmp(envar, "zvector") == 0) {
      return CPUCapability::ZVECTOR;
    }
#else
#ifdef HAVE_AVX512_CPU_DEFINITION
    if (strcmp(envar, "avx512") == 0) {
      return CPUCapability::AVX512;
    }
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    if (strcmp(envar, "avx2") == 0) {
      return CPUCapability::AVX2;
    }
#endif
#endif
    // "default" (scalar kernels) is always a valid override.
    if (strcmp(envar, "default") == 0) {
      return CPUCapability::DEFAULT;
    }
    // Unrecognized value: warn and fall through to runtime detection.
    TORCH_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", envar);
  }

  // cpuinfo is not built on ppc64/s390x (see the include guard at the top
  // of this file), so runtime x86 detection is compiled out there.
#if !defined(__powerpc__) && !defined(__s390x__)
  if (cpuinfo_initialize()) {
#if defined(HAVE_AVX512_CPU_DEFINITION)
    // GCC supports some AVX512 intrinsics such as _mm512_set_epi16 only in
    // versions 9 & beyond. So, we want to ensure that only releases built with
    // supported compilers on supported hardware return CPU Capability AVX512,
    // if it's supported on the hardware PyTorch is running on.
    if (cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && \
        cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_fma3()) {
      return CPUCapability::AVX512;
    }
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) {
      return CPUCapability::AVX2;
    }
#endif
  }
#endif

#ifdef HAVE_ZVECTOR_CPU_DEFINITION
  // vxe is needed for fp32 vector instructions
  if (cpu_has_vxe()) {
    return CPUCapability::ZVECTOR;
  }
#endif

  // VSX builds assume VSX is always present on the target hardware;
  // everything else falls back to the scalar DEFAULT kernels.
#ifdef HAVE_VSX_CPU_DEFINITION
  return CPUCapability::VSX;
#else
  return CPUCapability::DEFAULT;
#endif
}
88
get_cpu_capability()89 CPUCapability get_cpu_capability() {
90 static CPUCapability capability = compute_cpu_capability();
91 return capability;
92 }
93
try_get_call_ptr(const DeviceType device_type,void * DEFAULT,void * AVX512,void * AVX2,void * VSX,void * ZVECTOR)94 DispatchResult DispatchStubImpl::try_get_call_ptr(
95 const DeviceType device_type
96 , void *DEFAULT
97 #ifdef HAVE_AVX512_CPU_DEFINITION
98 , void *AVX512
99 #endif
100 #ifdef HAVE_AVX2_CPU_DEFINITION
101 , void *AVX2
102 #endif
103 #ifdef HAVE_VSX_CPU_DEFINITION
104 , void *VSX
105 #endif
106 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
107 , void *ZVECTOR
108 #endif
109 ) {
110 constexpr auto supported_devices = c10::array_of<c10::DeviceType>(
111 c10::DeviceType::CPU,
112 c10::DeviceType::CUDA,
113 c10::DeviceType::HIP,
114 c10::DeviceType::MPS,
115 c10::DeviceType::MTIA,
116 c10::DeviceType::XPU,
117 c10::DeviceType::PrivateUse1
118 );
119 // Check if the device type is supported.
120 if (std::find(supported_devices.begin(), supported_devices.end(), device_type) == supported_devices.end()) {
121 return ErrorType::DeviceNotSupported;
122 }
123 switch (device_type) {
124 case DeviceType::CPU: {
125 // Use memory_order_relaxed here since even if two threads race,
126 // they will still compute the same value for cpu_dispatch_ptr.
127 auto fptr = cpu_dispatch_ptr.load(std::memory_order_relaxed);
128 if (!fptr) {
129 auto result = try_choose_cpu_impl(
130 DEFAULT
131 #ifdef HAVE_AVX512_CPU_DEFINITION
132 , AVX512
133 #endif
134 #ifdef HAVE_AVX2_CPU_DEFINITION
135 , AVX2
136 #endif
137 #ifdef HAVE_VSX_CPU_DEFINITION
138 , VSX
139 #endif
140 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
141 , ZVECTOR
142 #endif
143 );
144 if (!std::holds_alternative<ErrorType>(result)) {
145 cpu_dispatch_ptr.store(fptr, std::memory_order_relaxed);
146 }
147 return result;
148 }
149 return DispatchResult(fptr);
150 }
151
152 case DeviceType::CUDA:
153 return cuda_dispatch_ptr != nullptr ? DispatchResult(cuda_dispatch_ptr) : ErrorType::MissingDeviceKernel;
154
155 case DeviceType::HIP:
156 return hip_dispatch_ptr != nullptr ? DispatchResult(hip_dispatch_ptr) : ErrorType::MissingDeviceKernel;
157
158 #if defined(USE_MPS)
159 case DeviceType::MPS:
160 return mps_dispatch_ptr != nullptr ? DispatchResult(mps_dispatch_ptr) : ErrorType::MissingDeviceKernel;
161 #endif
162 case DeviceType::MTIA:
163 return mtia_dispatch_ptr != nullptr ? DispatchResult(mtia_dispatch_ptr) : ErrorType::MissingDeviceKernel;
164
165 #if defined(USE_XPU)
166 case DeviceType::XPU:
167 return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel;
168 #endif
169
170 case DeviceType::PrivateUse1:
171 return privateuse1_dispatch_ptr != nullptr ? DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel;
172
173 default:
174 TORCH_INTERNAL_ASSERT(false, "An unexpected device type was provided ", device_type);
175 return ErrorType::DeviceNotSupported;
176 }
177 }
178
get_call_ptr(const DeviceType device_type,void * DEFAULT,void * AVX512,void * AVX2,void * VSX,void * ZVECTOR)179 void* DispatchStubImpl::get_call_ptr(
180 const DeviceType device_type
181 , void *DEFAULT
182 #ifdef HAVE_AVX512_CPU_DEFINITION
183 , void *AVX512
184 #endif
185 #ifdef HAVE_AVX2_CPU_DEFINITION
186 , void *AVX2
187 #endif
188 #ifdef HAVE_VSX_CPU_DEFINITION
189 , void *VSX
190 #endif
191 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
192 , void *ZVECTOR
193 #endif
194 ) {
195
196 auto result = try_get_call_ptr(
197 device_type,
198 DEFAULT
199 #ifdef HAVE_AVX512_CPU_DEFINITION
200 ,
201 AVX512
202 #endif
203 #ifdef HAVE_AVX2_CPU_DEFINITION
204 ,
205 AVX2
206 #endif
207 #ifdef HAVE_VSX_CPU_DEFINITION
208 ,
209 VSX
210 #endif
211 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
212 ,
213 ZVECTOR
214 #endif
215 );
216 if (std::holds_alternative<ErrorType>(result)) {
217 auto error = std::get<ErrorType>(result);
218 switch (error) {
219 case ErrorType::MissingDeviceKernel:
220 TORCH_INTERNAL_ASSERT(
221 false, "DispatchStub: missing kernel for ", device_type);
222 return nullptr;
223 case ErrorType::DeviceNotSupported:
224 AT_ERROR("DispatchStub: unsupported device type", device_type);
225 }
226 }
227
228 void* fptr = std::get<void*>(result);
229 return fptr;
230 }
231
try_choose_cpu_impl(void * DEFAULT,void * AVX512,void * AVX2,void * VSX,void * ZVECTOR)232 DispatchResult DispatchStubImpl::try_choose_cpu_impl(
233 void *DEFAULT
234 #ifdef HAVE_AVX512_CPU_DEFINITION
235 , void *AVX512
236 #endif
237 #ifdef HAVE_AVX2_CPU_DEFINITION
238 , void *AVX2
239 #endif
240 #ifdef HAVE_VSX_CPU_DEFINITION
241 , void *VSX
242 #endif
243 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
244 , void *ZVECTOR
245 #endif
246 ){
247
248 auto capability = static_cast<int>(get_cpu_capability());
249 (void)capability;
250 #ifdef HAVE_AVX512_CPU_DEFINITION
251 if (capability >= static_cast<int>(CPUCapability::AVX512)) {
252 // Quantization kernels have also been disabled on Windows
253 // for AVX512 because some of their tests are flaky on Windows.
254 // Ideally, we should have AVX512 kernels for all kernels.
255 if (C10_UNLIKELY(!AVX512)) {
256 // dispatch to AVX2, since the AVX512 kernel is missing
257 return AVX2 != nullptr ? DispatchResult(AVX2) : ErrorType::MissingDeviceKernel;
258 } else {
259 return DispatchResult(AVX512);
260 }
261 }
262 #endif
263 #ifdef HAVE_AVX2_CPU_DEFINITION
264 if (capability >= static_cast<int>(CPUCapability::AVX2)) {
265 return AVX2 != nullptr ? DispatchResult(AVX2) : ErrorType::MissingDeviceKernel;
266 }
267 #endif
268 #ifdef HAVE_VSX_CPU_DEFINITION
269 if (capability >= static_cast<int>(CPUCapability::VSX)) {
270 return VSX != nullptr ? DispatchResult(VSX) : ErrorType::MissingDeviceKernel;
271 }
272 #endif
273 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
274 if (capability >= static_cast<int>(CPUCapability::ZVECTOR)) {
275 return ZVECTOR != nullptr ? DispatchResult(ZVECTOR) : ErrorType::MissingDeviceKernel;
276 }
277 #endif
278 return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
279 }
280
choose_cpu_impl(void * DEFAULT,void * AVX512,void * AVX2,void * VSX,void * ZVECTOR)281 void* DispatchStubImpl::choose_cpu_impl(
282 void *DEFAULT
283 #ifdef HAVE_AVX512_CPU_DEFINITION
284 , void *AVX512
285 #endif
286 #ifdef HAVE_AVX2_CPU_DEFINITION
287 , void *AVX2
288 #endif
289 #ifdef HAVE_VSX_CPU_DEFINITION
290 , void *VSX
291 #endif
292 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
293 , void *ZVECTOR
294 #endif
295 ) {
296 auto capability = static_cast<int>(get_cpu_capability());
297 (void)capability;
298 #ifdef HAVE_AVX512_CPU_DEFINITION
299 if (capability >= static_cast<int>(CPUCapability::AVX512)) {
300 // Quantization kernels have also been disabled on Windows
301 // for AVX512 because some of their tests are flaky on Windows.
302 // Ideally, we should have AVX512 kernels for all kernels.
303 if (C10_UNLIKELY(!AVX512)) {
304 // dispatch to AVX2, since the AVX512 kernel is missing
305 TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel");
306 return AVX2;
307 } else {
308 return AVX512;
309 }
310 }
311 #endif
312 #ifdef HAVE_AVX2_CPU_DEFINITION
313 if (capability >= static_cast<int>(CPUCapability::AVX2)) {
314 TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel");
315 return AVX2;
316 }
317 #endif
318 #ifdef HAVE_VSX_CPU_DEFINITION
319 if (capability >= static_cast<int>(CPUCapability::VSX)) {
320 TORCH_INTERNAL_ASSERT(VSX, "DispatchStub: missing VSX kernel");
321 return VSX;
322 }
323 #endif
324 #ifdef HAVE_ZVECTOR_CPU_DEFINITION
325 if (capability >= static_cast<int>(CPUCapability::ZVECTOR)) {
326 TORCH_INTERNAL_ASSERT(ZVECTOR, "DispatchStub: missing ZVECTOR kernel");
327 return ZVECTOR;
328 }
329 #endif
330 TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
331 return DEFAULT;
332 }
333
334 } // namespace at::native
335