1#!/usr/bin/env python3 2 3from __future__ import print_function 4import collections 5import os 6import sys 7import logging 8 9BANNER = "Auto-generated by generate-wrappers.py script. Do not modify" 10WRAPPER_SRC_NAMES = { 11 "PROD_SCALAR_MICROKERNEL_SRCS": None, 12 "PROD_FMA_MICROKERNEL_SRCS": "defined(__riscv) || defined(__riscv__)", 13 "PROD_ARMSIMD32_MICROKERNEL_SRCS": "defined(__arm__)", 14 "PROD_FP16ARITH_MICROKERNEL_SRCS": "defined(__arm__)", 15 "PROD_NEON_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)", 16 "PROD_NEONFP16_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)", 17 "PROD_NEONFMA_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)", 18 "PROD_NEON_AARCH64_MICROKERNEL_SRCS": "defined(__aarch64__)", 19 "PROD_NEONV8_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)", 20 "PROD_NEONFP16ARITH_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)", 21 "PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS": "defined(__aarch64__)", 22 "PROD_NEONDOT_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)", 23 "PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS": "defined(__aarch64__)", 24 "PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)", 25 "PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS": "defined(__aarch64__)", 26 "PROD_NEONI8MM_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)", 27 "PROD_SSE_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 28 "PROD_SSE2_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 29 "PROD_SSSE3_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 30 "PROD_SSE41_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 31 "PROD_AVX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 32 "PROD_F16C_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 33 "PROD_XOP_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 34 "PROD_FMA3_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 35 "PROD_AVX2_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 36 "PROD_AVX512F_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 37 "PROD_AVX512SKX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 38 "PROD_AVX512VBMI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 39 "PROD_AVX512VNNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 40 "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 41 "PROD_RVV_MICROKERNEL_SRCS": "defined(__riscv) || defined(__riscv__)", 42 "PROD_AVXVNNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", 43 "AARCH32_ASM_MICROKERNEL_SRCS": "defined(__arm__)", 44 "AARCH64_ASM_MICROKERNEL_SRCS": "defined(__aarch64__)", 45 46 # add non-prod microkernel sources here: 47} 48 49SRC_NAMES = { 50 "OPERATOR_SRCS", 51 "SUBGRAPH_SRCS", 52 "LOGGING_SRCS", 53 "XNNPACK_SRCS", 54 "TABLE_SRCS", 55 "JIT_SRCS", 56 "PROD_SCALAR_MICROKERNEL_SRCS", 57 "PROD_FMA_MICROKERNEL_SRCS", 58 "PROD_ARMSIMD32_MICROKERNEL_SRCS", 59 "PROD_FP16ARITH_MICROKERNEL_SRCS", 60 "PROD_NEON_MICROKERNEL_SRCS", 61 "PROD_NEONFP16_MICROKERNEL_SRCS", 62 "PROD_NEONFMA_MICROKERNEL_SRCS", 63 "PROD_NEON_AARCH64_MICROKERNEL_SRCS", 64 "PROD_NEONV8_MICROKERNEL_SRCS", 65 "PROD_NEONFP16ARITH_MICROKERNEL_SRCS", 66 "PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS", 67 "PROD_NEONDOT_MICROKERNEL_SRCS", 68 "PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS", 69 "PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS", 70 "PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS", 71 "PROD_NEONI8MM_MICROKERNEL_SRCS", 72 "PROD_SSE_MICROKERNEL_SRCS", 73 "PROD_SSE2_MICROKERNEL_SRCS", 74 "PROD_SSSE3_MICROKERNEL_SRCS", 75 "PROD_SSE41_MICROKERNEL_SRCS", 76 "PROD_AVX_MICROKERNEL_SRCS", 77 "PROD_F16C_MICROKERNEL_SRCS", 78 "PROD_XOP_MICROKERNEL_SRCS", 79 "PROD_FMA3_MICROKERNEL_SRCS", 80 "PROD_AVX2_MICROKERNEL_SRCS", 81 "PROD_AVX512F_MICROKERNEL_SRCS", 82 "PROD_AVX512SKX_MICROKERNEL_SRCS", 83 "PROD_AVX512VBMI_MICROKERNEL_SRCS", 84 "PROD_AVX512VNNI_MICROKERNEL_SRCS", 85 "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS", 86 "PROD_RVV_MICROKERNEL_SRCS", 87 "PROD_AVXVNNI_MICROKERNEL_SRCS", 88 "AARCH32_ASM_MICROKERNEL_SRCS", 89 "AARCH64_ASM_MICROKERNEL_SRCS", 90 91 # add non-prod microkernel sources here: 92} 93 94def handle_singleline_parse(line): 95 start_index = line.find("(") 96 end_index = line.find(")") 97 line = line[start_index+1:end_index] 98 key_val = line.split(" ") 99 return key_val[0], [x[4:] for x in key_val[1:]] 100 101def update_sources(xnnpack_path, cmakefile = "XNNPACK/CMakeLists.txt"): 102 sources = collections.defaultdict(list) 103 with open(os.path.join(xnnpack_path, cmakefile)) as cmake: 104 lines = cmake.readlines() 105 i = 0 106 while i < len(lines): 107 line = lines[i] 108 109 if lines[i].startswith("SET") and "src/" in lines[i]: 110 name, val = handle_singleline_parse(line) 111 sources[name].extend(val) 112 i+=1 113 continue 114 115 if line.startswith("SET") and line.split('(')[1].strip(' \t\n\r') in set(WRAPPER_SRC_NAMES.keys()) | set(SRC_NAMES): 116 name = line.split('(')[1].strip(' \t\n\r') 117 i += 1 118 while i < len(lines) and len(lines[i]) > 0 and ')' not in lines[i]: 119 # remove "src/" at the beginning, remove whitespaces and newline 120 value = lines[i].strip(' \t\n\r') 121 sources[name].append(value[4:]) 122 i += 1 123 if i < len(lines) and len(lines[i]) > 4: 124 # remove "src/" at the beginning, possibly ')' at the end 125 value = lines[i].strip(' \t\n\r)') 126 sources[name].append(value[4:]) 127 else: 128 i += 1 129 return sources 130 131def gen_wrappers(xnnpack_path): 132 xnnpack_sources = collections.defaultdict(list) 133 sources = update_sources(xnnpack_path) 134 135 microkernels_sources = update_sources(xnnpack_path, "XNNPACK/cmake/microkernels.cmake") 136 for key in microkernels_sources: 137 sources[key] = microkernels_sources[key] 138 139 for name in WRAPPER_SRC_NAMES: 140 xnnpack_sources[WRAPPER_SRC_NAMES[name]].extend(sources[name]) 141 142 for condition, filenames in xnnpack_sources.items(): 143 print(condition) 144 for filename in filenames: 145 filepath = os.path.join(xnnpack_path, "xnnpack_wrappers", filename) 146 147 if not os.path.isdir(os.path.dirname(filepath)): 148 os.makedirs(os.path.dirname(filepath)) 149 with open(filepath, "w") as wrapper: 150 print("/* {} */".format(BANNER), file=wrapper) 151 print(file=wrapper) 152 153 # Architecture- or platform-dependent preprocessor flags can be 154 # defined here. Note: platform_preprocessor_flags can't be used 155 # because they are ignored by arc focus & buck project. 156 157 if condition is None: 158 print("#include <%s>" % filename, file=wrapper) 159 else: 160 # Include source file only if condition is satisfied 161 print("#if %s" % condition, file=wrapper) 162 print("#include <%s>" % filename, file=wrapper) 163 print("#endif /* %s */" % condition, file=wrapper) 164 165 # update xnnpack_wrapper_defs.bzl file under the same folder 166 with open(os.path.join(os.path.dirname(__file__), "xnnpack_wrapper_defs.bzl"), 'w') as wrapper_defs: 167 print('"""', file=wrapper_defs) 168 print(BANNER, file=wrapper_defs) 169 print('"""', file=wrapper_defs) 170 for name in WRAPPER_SRC_NAMES: 171 print('\n' + name + ' = [', file=wrapper_defs) 172 for file_name in sources[name]: 173 print(' "xnnpack_wrappers/{}",'.format(file_name), file=wrapper_defs) 174 print(']', file=wrapper_defs) 175 176 # update xnnpack_src_defs.bzl file under the same folder 177 with open(os.path.join(os.path.dirname(__file__), "xnnpack_src_defs.bzl"), 'w') as src_defs: 178 print('"""', file=src_defs) 179 print(BANNER, file=src_defs) 180 print('"""', file=src_defs) 181 for name in SRC_NAMES: 182 print('\n' + name + ' = [', file=src_defs) 183 for file_name in sources[name]: 184 print(' "XNNPACK/src/{}",'.format(file_name), file=src_defs) 185 print(']', file=src_defs) 186 187 188def main(argv): 189 if argv is None or len(argv) == 0: 190 gen_wrappers(".") 191 else: 192 gen_wrappers(argv[0]) 193 194# The first argument is the place where the "xnnpack_wrappers" folder will be created. 195# Run it without arguments will generate "xnnpack_wrappers" in the current path. 196# The two .bzl files will always be generated in the current path. 197if __name__ == "__main__": 198 main(sys.argv[1:]) 199