# Source: /aosp_15_r20/external/pytorch/cmake/Modules/FindAVX.cmake
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
include(CheckCSourceCompiles)
include(CheckCSourceRuns)
include(CheckCXXSourceCompiles)
include(CheckCXXSourceRuns)

# AVX probe source, consumed by CHECK_SSE through the variable name
# "${type}_CODE" with type = AVX. Declaring an __m256 and filling it with
# _mm256_set1_ps only compiles when AVX intrinsics are available under the
# candidate flags being tried.
set(AVX_CODE "
  #include <immintrin.h>

  int main()
  {
    __m256 a;
    a = _mm256_set1_ps(0);
    return 0;
  }
")

# AVX-512 probe source, consumed by CHECK_SSE with type = AVX512.
# It builds a 512-bit integer vector from 64 byte lanes and compares two such
# vectors byte-wise into a 64-bit mask (_mm512_cmp_epi8_mask). The candidate
# flag sets for this probe enable the F, DQ, VL and BW subsets.
set(AVX512_CODE "
  #include <immintrin.h>

  int main()
  {
    __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0);
    __m512i b = a;
    __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
    return 0;
  }
")

# AVX2 probe source, consumed by CHECK_SSE with type = AVX2.
# _mm256_abs_epi16 exercises an AVX2 integer intrinsic. Per the snippet's
# own comment, the AVX2 code relies on _mm256_extract_epi64, so the probe
# requires it too.
# The extract reads the initialized vector `a`. Reading an uninitialized
# __m256i is undefined behaviour and can trigger -Wuninitialized. If the
# flags in effect contain -Werror, that warning would make the probe fail
# and report AVX2 as unsupported even though it is available.
set(AVX2_CODE "
  #include <immintrin.h>

  int main()
  {
    __m256i a = {0};
    a = _mm256_abs_epi16(a);
    (void)_mm256_extract_epi64(a, 0); // we rely on this in our AVX2 code
    return 0;
  }
")

# CHECK_SSE(<lang> <type> <flags>)
#
# Checks whether the <lang> compiler ("C" or "CXX") can compile the probe
# source stored in the variable <type>_CODE (for example AVX_CODE).
#
# <flags> is a ;-separated list of candidate flag sets. Candidates are tried
# in order, and each one is used on its own as CMAKE_REQUIRED_FLAGS; any
# value the caller had is replaced, not appended to. The first candidate
# that compiles wins. A " " candidate adds no real flag, so it probes the
# compiler's defaults.
#
# Cache entries written (marked advanced):
#   <lang>_<type>_FOUND  TRUE if some candidate compiled, otherwise FALSE
#   <lang>_<type>_FLAGS  the candidate that compiled, otherwise ""
# Each individual probe result is cached under the name
# <lang>_HAS_<type>_<N>, where N is the 1-based candidate index.
#
# Example:
#   CHECK_SSE(CXX "AVX2" " ;-mavx2 -mfma -mf16c;/arch:AVX2")
#
# Why a function and not a macro:
# - Macro parameters are not variables, so `if(lang ...)` inside a macro
#   never sees the argument and would always pick the C branch.
# - Function scope keeps the temporary CMAKE_REQUIRED_FLAGS and the loop
#   variables from leaking into the caller.
# All results are cache entries, so they remain visible to the caller.
# The CXX branch needs include(CheckCXXSourceCompiles) at the top of the file.
function(CHECK_SSE lang type flags)
  set(candidate_index 1)
  foreach(candidate ${flags})
    # Stop at the first working candidate (or a cached positive result).
    if(${lang}_${type}_FOUND)
      break()
    endif()

    # Function-local: the caller's CMAKE_REQUIRED_FLAGS is left untouched.
    set(CMAKE_REQUIRED_FLAGS "${candidate}")
    if("${lang}" STREQUAL "CXX")
      check_cxx_source_compiles("${${type}_CODE}" ${lang}_HAS_${type}_${candidate_index})
    else()
      check_c_source_compiles("${${type}_CODE}" ${lang}_HAS_${type}_${candidate_index})
    endif()

    if(${lang}_HAS_${type}_${candidate_index})
      set(${lang}_${type}_FOUND TRUE CACHE BOOL "${lang} ${type} support")
      set(${lang}_${type}_FLAGS "${candidate}" CACHE STRING "${lang} ${type} flags")
    endif()
    math(EXPR candidate_index "${candidate_index} + 1")
  endforeach()

  # No candidate compiled: record an explicit negative result.
  if(NOT ${lang}_${type}_FOUND)
    set(${lang}_${type}_FOUND FALSE CACHE BOOL "${lang} ${type} support")
    set(${lang}_${type}_FLAGS "" CACHE STRING "${lang} ${type} flags")
  endif()

  mark_as_advanced(${lang}_${type}_FOUND ${lang}_${type}_FLAGS)
endfunction()

# Probe the C compiler first, then the C++ compiler, for AVX, AVX2 and
# AVX-512. Each candidate list starts with " " (no extra flag), so the
# compiler's defaults are tried before the -m... and /arch:... spellings.
check_sse(C   AVX    " ;-mavx;/arch:AVX")
check_sse(C   AVX2   " ;-mavx2 -mfma -mf16c;/arch:AVX2")
check_sse(C   AVX512 " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")

check_sse(CXX AVX    " ;-mavx;/arch:AVX")
check_sse(CXX AVX2   " ;-mavx2 -mfma -mf16c;/arch:AVX2")
check_sse(CXX AVX512 " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")