xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Inspired by
2 // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
3 // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
4 
5 #pragma once
6 
/// Turns a runtime boolean into a compile-time constant visible to a callable.
///
/// @param COND       - boolean expression evaluated once at runtime.
/// @param CONST_NAME - identifier under which the callable sees the value as a
///                     `constexpr static bool`.
/// @param ...        - a callable taking no arguments (typically a lambda). It
///                     is expanded, and therefore compiled, once per branch:
///                     once with CONST_NAME == true, once with CONST_NAME ==
///                     false. Its return value is forwarded to the caller.
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```

#define BOOL_SWITCH(COND, CONST_NAME, ...)       \
  [&] {                                          \
    if (COND) {                                  \
      static constexpr bool CONST_NAME = true;   \
      return __VA_ARGS__();                      \
    }                                            \
    static constexpr bool CONST_NAME = false;    \
    return __VA_ARGS__();                        \
  }()
28 
// DROPOUT_SWITCH(COND, CONST_NAME, ...)
// By default this is an alias of BOOL_SWITCH. When
// FLASHATTENTION_DISABLE_DROPOUT is defined, COND is never evaluated,
// CONST_NAME is pinned to false, and the callable is expanded only once.
#ifndef FLASHATTENTION_DISABLE_DROPOUT
  #define DROPOUT_SWITCH BOOL_SWITCH
#else
  #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
  [&] {                                         \
    static constexpr bool CONST_NAME = false;   \
    return __VA_ARGS__();                       \
  }()
#endif
38 
// ALIBI_SWITCH(COND, CONST_NAME, ...)
// By default this is an alias of BOOL_SWITCH. When
// FLASHATTENTION_DISABLE_ALIBI is defined, COND is never evaluated,
// CONST_NAME is pinned to false, and the callable is expanded only once.
#ifndef FLASHATTENTION_DISABLE_ALIBI
  #define ALIBI_SWITCH BOOL_SWITCH
#else
  #define ALIBI_SWITCH(COND, CONST_NAME, ...)   \
  [&] {                                         \
    static constexpr bool CONST_NAME = false;   \
    return __VA_ARGS__();                       \
  }()
#endif
48 
// EVENK_SWITCH(COND, CONST_NAME, ...)
// By default this is an alias of BOOL_SWITCH. When
// FLASHATTENTION_DISABLE_UNEVEN_K is defined, COND is never evaluated and
// CONST_NAME is pinned to true (note: true, unlike the other *_SWITCH
// macros in this file, which pin to false), so the callable is expanded
// only once.
#ifndef FLASHATTENTION_DISABLE_UNEVEN_K
  #define EVENK_SWITCH BOOL_SWITCH
#else
  #define EVENK_SWITCH(COND, CONST_NAME, ...)   \
  [&] {                                         \
    static constexpr bool CONST_NAME = true;    \
    return __VA_ARGS__();                       \
  }()
#endif
58 
// LOCAL_SWITCH(COND, CONST_NAME, ...)
// By default this is an alias of BOOL_SWITCH. When
// FLASHATTENTION_DISABLE_LOCAL is defined, COND is never evaluated,
// CONST_NAME is pinned to false, and the callable is expanded only once.
#ifndef FLASHATTENTION_DISABLE_LOCAL
  #define LOCAL_SWITCH BOOL_SWITCH
#else
  #define LOCAL_SWITCH(COND, CONST_NAME, ...)   \
  [&] {                                         \
    static constexpr bool CONST_NAME = false;   \
    return __VA_ARGS__();                       \
  }()
#endif
68 
/// Selects the element type seen by a callable from a runtime boolean.
///
/// @param COND - boolean expression evaluated once at runtime.
/// @param ...  - a callable taking no arguments. Inside it, `elem_type` names
///               cutlass::half_t when COND is true and cutlass::bfloat16_t
///               otherwise. The callable is expanded once per type and its
///               return value is forwarded to the caller.
///
/// The cutlass types must already be declared wherever this macro is used;
/// this header does not include them itself.
#define FP16_SWITCH(COND, ...)                \
  [&] {                                       \
    if (COND) {                               \
      using elem_type = cutlass::half_t;      \
      return __VA_ARGS__();                   \
    }                                         \
    using elem_type = cutlass::bfloat16_t;    \
    return __VA_ARGS__();                     \
  }()
79 
/// Maps a runtime head dimension onto a compile-time constant.
///
/// @param HEADDIM - integer expression. It is compared against each bucket in
///                  turn, so it may be evaluated up to eight times and should
///                  be free of side effects. It is parenthesized in every
///                  comparison, so expressions containing low-precedence
///                  operators (`?:`, `|`, `&`, ...) are safe to pass.
/// @param ...     - a callable taking no arguments. Inside it, `kHeadDim` is a
///                  `constexpr static int` equal to the smallest of
///                  32, 64, 96, 128, 160, 192, 224, 256 that is >= HEADDIM.
///                  The callable is expanded once per bucket and its return
///                  value is forwarded to the caller.
///
/// NOTE(review): a HEADDIM greater than 256 matches no branch. The callable
/// is then not invoked and the lambda ends without a return statement, which
/// is undefined behavior if the callable returns a non-void type. Callers are
/// expected to validate HEADDIM <= 256 before using this macro.
#define HEADDIM_SWITCH(HEADDIM, ...)         \
  [&] {                                      \
    if ((HEADDIM) <= 32) {                   \
      constexpr static int kHeadDim = 32;    \
      return __VA_ARGS__();                  \
    } else if ((HEADDIM) <= 64) {            \
      constexpr static int kHeadDim = 64;    \
      return __VA_ARGS__();                  \
    } else if ((HEADDIM) <= 96) {            \
      constexpr static int kHeadDim = 96;    \
      return __VA_ARGS__();                  \
    } else if ((HEADDIM) <= 128) {           \
      constexpr static int kHeadDim = 128;   \
      return __VA_ARGS__();                  \
    } else if ((HEADDIM) <= 160) {           \
      constexpr static int kHeadDim = 160;   \
      return __VA_ARGS__();                  \
    } else if ((HEADDIM) <= 192) {           \
      constexpr static int kHeadDim = 192;   \
      return __VA_ARGS__();                  \
    } else if ((HEADDIM) <= 224) {           \
      constexpr static int kHeadDim = 224;   \
      return __VA_ARGS__();                  \
    } else if ((HEADDIM) <= 256) {           \
      constexpr static int kHeadDim = 256;   \
      return __VA_ARGS__();                  \
    }                                        \
  }()
108