// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
//
// Helpers that turn a runtime value into a compile-time constant (or a type
// alias) and then invoke a caller-supplied callable with that name in scope.
// Every macro expands to an immediately-invoked `[&]` lambda, so the expansion
// is an expression whose value is whatever the callable returns.
//
// NOTE: comments are kept outside the macro bodies on purpose: a `//` comment
// on a backslash-continued line would splice the following line into it.

#pragma once

#include <stdexcept>

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///   some_function<BoolConst>(...);
/// });
/// ```
///
/// COND is evaluated exactly once. CONST_NAME is declared `constexpr static`
/// so that a nested lambda can read it without capturing it. The callable is
/// expanded textually in both branches, so both expansions must deduce the
/// same return type.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      constexpr static bool CONST_NAME = true;  \
      return __VA_ARGS__();                     \
    } else {                                    \
      constexpr static bool CONST_NAME = false; \
      return __VA_ARGS__();                     \
    }                                           \
  }()

// Feature switches. When the matching FLASHATTENTION_DISABLE_* macro is defined
// at compile time, the switch pins CONST_NAME to a fixed value, expands the
// callable only once, and does NOT evaluate COND at all. A COND that disagrees
// with the pinned value is therefore silently ignored; callers must reject it
// beforehand if that matters. Otherwise the switch is an alias for BOOL_SWITCH.

// Pinned to false when FLASHATTENTION_DISABLE_DROPOUT is defined.
#ifdef FLASHATTENTION_DISABLE_DROPOUT
#define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
  [&] {                                       \
    constexpr static bool CONST_NAME = false; \
    return __VA_ARGS__();                     \
  }()
#else
#define DROPOUT_SWITCH BOOL_SWITCH
#endif

// Pinned to false when FLASHATTENTION_DISABLE_ALIBI is defined.
#ifdef FLASHATTENTION_DISABLE_ALIBI
#define ALIBI_SWITCH(COND, CONST_NAME, ...)   \
  [&] {                                       \
    constexpr static bool CONST_NAME = false; \
    return __VA_ARGS__();                     \
  }()
#else
#define ALIBI_SWITCH BOOL_SWITCH
#endif

// Pinned to true when FLASHATTENTION_DISABLE_UNEVEN_K is defined.
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
#define EVENK_SWITCH(COND, CONST_NAME, ...)  \
  [&] {                                      \
    constexpr static bool CONST_NAME = true; \
    return __VA_ARGS__();                    \
  }()
#else
#define EVENK_SWITCH BOOL_SWITCH
#endif

// Pinned to false when FLASHATTENTION_DISABLE_LOCAL is defined.
#ifdef FLASHATTENTION_DISABLE_LOCAL
#define LOCAL_SWITCH(COND, CONST_NAME, ...)   \
  [&] {                                       \
    constexpr static bool CONST_NAME = false; \
    return __VA_ARGS__();                     \
  }()
#else
#define LOCAL_SWITCH BOOL_SWITCH
#endif

/// Binds the type alias `elem_type` to `cutlass::half_t` when COND is true and
/// to `cutlass::bfloat16_t` otherwise, then invokes the callable.
/// This header does not include CUTLASS; the `cutlass` types must already be
/// declared wherever the macro is expanded.
#define FP16_SWITCH(COND, ...)               \
  [&] {                                      \
    if (COND) {                              \
      using elem_type = cutlass::half_t;     \
      return __VA_ARGS__();                  \
    } else {                                 \
      using elem_type = cutlass::bfloat16_t; \
      return __VA_ARGS__();                  \
    }                                        \
  }()

/// Rounds HEADDIM up to the nearest supported head dimension (32, 64, 96, 128,
/// 160, 192, 224 or 256), exposes it as `constexpr static int kHeadDim`, and
/// invokes the callable.
///
/// HEADDIM is parenthesized and evaluated exactly once into a local. This makes
/// arguments with side effects, or with low-precedence operators such as
/// `a & b`, behave as written. A value greater than 256 throws
/// std::invalid_argument. Without that branch the lambda would fall off its
/// end, silently skipping the callable, or invoking undefined behavior when
/// the callable returns a value.
#define HEADDIM_SWITCH(HEADDIM, ...)                              \
  [&] {                                                           \
    const auto flash_switch_headdim_ = (HEADDIM);                 \
    if (flash_switch_headdim_ <= 32) {                            \
      constexpr static int kHeadDim = 32;                         \
      return __VA_ARGS__();                                       \
    } else if (flash_switch_headdim_ <= 64) {                     \
      constexpr static int kHeadDim = 64;                         \
      return __VA_ARGS__();                                       \
    } else if (flash_switch_headdim_ <= 96) {                     \
      constexpr static int kHeadDim = 96;                         \
      return __VA_ARGS__();                                       \
    } else if (flash_switch_headdim_ <= 128) {                    \
      constexpr static int kHeadDim = 128;                        \
      return __VA_ARGS__();                                       \
    } else if (flash_switch_headdim_ <= 160) {                    \
      constexpr static int kHeadDim = 160;                        \
      return __VA_ARGS__();                                       \
    } else if (flash_switch_headdim_ <= 192) {                    \
      constexpr static int kHeadDim = 192;                        \
      return __VA_ARGS__();                                       \
    } else if (flash_switch_headdim_ <= 224) {                    \
      constexpr static int kHeadDim = 224;                        \
      return __VA_ARGS__();                                       \
    } else if (flash_switch_headdim_ <= 256) {                    \
      constexpr static int kHeadDim = 256;                        \
      return __VA_ARGS__();                                       \
    } else {                                                      \
      throw std::invalid_argument(                                \
          "HEADDIM_SWITCH: head dimension must be at most 256");  \
    }                                                             \
  }()