xref: /aosp_15_r20/external/pytorch/third_party/generate-xnnpack-wrappers.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3from __future__ import print_function
4import collections
5import os
6import sys
7import logging
8
# Header text stamped into every generated wrapper and .bzl file.
BANNER = "Auto-generated by generate-wrappers.py script. Do not modify"

# Preprocessor guard expressions shared by the entries of WRAPPER_SRC_NAMES.
_GUARD_RISCV = "defined(__riscv) || defined(__riscv__)"
_GUARD_ARM32 = "defined(__arm__)"
_GUARD_ARM64 = "defined(__aarch64__)"
_GUARD_ARM_ANY = "defined(__arm__) || defined(__aarch64__)"
_GUARD_X86 = "defined(__i386__) || defined(__i686__) || defined(__x86_64__)"

# Maps each CMake microkernel source-list variable to the preprocessor
# condition that guards the #include in its generated wrappers; None means
# the wrapper includes the source unconditionally.
WRAPPER_SRC_NAMES = {
    "PROD_SCALAR_MICROKERNEL_SRCS": None,
    "PROD_FMA_MICROKERNEL_SRCS": _GUARD_RISCV,
    "PROD_ARMSIMD32_MICROKERNEL_SRCS": _GUARD_ARM32,
    "PROD_FP16ARITH_MICROKERNEL_SRCS": _GUARD_ARM32,
    "PROD_NEON_MICROKERNEL_SRCS": _GUARD_ARM_ANY,
    "PROD_NEONFP16_MICROKERNEL_SRCS": _GUARD_ARM_ANY,
    "PROD_NEONFMA_MICROKERNEL_SRCS": _GUARD_ARM_ANY,
    "PROD_NEON_AARCH64_MICROKERNEL_SRCS": _GUARD_ARM64,
    "PROD_NEONV8_MICROKERNEL_SRCS": _GUARD_ARM_ANY,
    "PROD_NEONFP16ARITH_MICROKERNEL_SRCS": _GUARD_ARM_ANY,
    "PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS": _GUARD_ARM64,
    "PROD_NEONDOT_MICROKERNEL_SRCS": _GUARD_ARM_ANY,
    "PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS": _GUARD_ARM64,
    "PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS": _GUARD_ARM_ANY,
    "PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS": _GUARD_ARM64,
    "PROD_NEONI8MM_MICROKERNEL_SRCS": _GUARD_ARM_ANY,
    "PROD_SSE_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_SSE2_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_SSSE3_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_SSE41_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_AVX_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_F16C_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_XOP_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_FMA3_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_AVX2_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_AVX512F_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_AVX512SKX_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_AVX512VBMI_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_AVX512VNNI_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS": _GUARD_X86,
    "PROD_RVV_MICROKERNEL_SRCS": _GUARD_RISCV,
    "PROD_AVXVNNI_MICROKERNEL_SRCS": _GUARD_X86,
    "AARCH32_ASM_MICROKERNEL_SRCS": _GUARD_ARM32,
    "AARCH64_ASM_MICROKERNEL_SRCS": _GUARD_ARM64,

    # add non-prod microkernel sources here:
}
48
# Names of the CMake source-list variables that are emitted as lists into
# xnnpack_src_defs.bzl. Together with the keys of WRAPPER_SRC_NAMES they are
# also the names update_sources() recognizes as multi-line SET(...) lists.
# NOTE: this is a set, so iterating it directly yields no guaranteed order.
SRC_NAMES = {
    "OPERATOR_SRCS",
    "SUBGRAPH_SRCS",
    "LOGGING_SRCS",
    "XNNPACK_SRCS",
    "TABLE_SRCS",
    "JIT_SRCS",
    "PROD_SCALAR_MICROKERNEL_SRCS",
    "PROD_FMA_MICROKERNEL_SRCS",
    "PROD_ARMSIMD32_MICROKERNEL_SRCS",
    "PROD_FP16ARITH_MICROKERNEL_SRCS",
    "PROD_NEON_MICROKERNEL_SRCS",
    "PROD_NEONFP16_MICROKERNEL_SRCS",
    "PROD_NEONFMA_MICROKERNEL_SRCS",
    "PROD_NEON_AARCH64_MICROKERNEL_SRCS",
    "PROD_NEONV8_MICROKERNEL_SRCS",
    "PROD_NEONFP16ARITH_MICROKERNEL_SRCS",
    "PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS",
    "PROD_NEONDOT_MICROKERNEL_SRCS",
    "PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS",
    "PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS",
    "PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS",
    "PROD_NEONI8MM_MICROKERNEL_SRCS",
    "PROD_SSE_MICROKERNEL_SRCS",
    "PROD_SSE2_MICROKERNEL_SRCS",
    "PROD_SSSE3_MICROKERNEL_SRCS",
    "PROD_SSE41_MICROKERNEL_SRCS",
    "PROD_AVX_MICROKERNEL_SRCS",
    "PROD_F16C_MICROKERNEL_SRCS",
    "PROD_XOP_MICROKERNEL_SRCS",
    "PROD_FMA3_MICROKERNEL_SRCS",
    "PROD_AVX2_MICROKERNEL_SRCS",
    "PROD_AVX512F_MICROKERNEL_SRCS",
    "PROD_AVX512SKX_MICROKERNEL_SRCS",
    "PROD_AVX512VBMI_MICROKERNEL_SRCS",
    "PROD_AVX512VNNI_MICROKERNEL_SRCS",
    "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS",
    "PROD_RVV_MICROKERNEL_SRCS",
    "PROD_AVXVNNI_MICROKERNEL_SRCS",
    "AARCH32_ASM_MICROKERNEL_SRCS",
    "AARCH64_ASM_MICROKERNEL_SRCS",

    # add non-prod microkernel sources here:
}
93
def handle_singleline_parse(line):
    """Parse a one-line CMake ``SET(NAME src/a.c src/b.c)`` statement.

    Args:
        line: A single line of CMake text holding the whole SET(...) call.

    Returns:
        A ``(name, sources)`` tuple: ``name`` is the first token inside the
        parentheses and ``sources`` lists the remaining tokens with a leading
        ``src/`` removed. An empty argument list yields ``("", [])``.
    """
    start_index = line.find("(")
    end_index = line.find(")", start_index + 1)
    if end_index == -1:
        # No closing parenthesis on this line: read to the end of the line
        # rather than letting the -1 index drop the final character.
        end_index = len(line)

    # split() without an argument collapses runs of spaces/tabs, so repeated
    # separators do not produce empty "source" entries.
    tokens = line[start_index + 1:end_index].split()
    if not tokens:
        return "", []

    prefix = "src/"
    # Only strip the prefix when it is actually there; blindly dropping four
    # characters would corrupt any entry that does not start with "src/".
    stripped = [
        token[len(prefix):] if token.startswith(prefix) else token
        for token in tokens[1:]
    ]
    return tokens[0], stripped
100
def _strip_src_prefix(entry):
    """Return `entry` without a leading "src/"; other entries are unchanged."""
    prefix = "src/"
    return entry[len(prefix):] if entry.startswith(prefix) else entry


def update_sources(xnnpack_path, cmakefile="XNNPACK/CMakeLists.txt"):
    """Collect the source lists declared by SET(...) statements in a CMake file.

    Two layouts are recognized for lines that start with ``SET``:

    * a line that also contains ``src/`` is parsed on its own by
      handle_singleline_parse(), whatever the variable name is;
    * a line naming one of WRAPPER_SRC_NAMES / SRC_NAMES starts a multi-line
      list whose entries follow one per line, up to and including the first
      line that contains ``)``.

    Args:
        xnnpack_path: Directory that `cmakefile` is resolved against.
        cmakefile: Path of the CMake file, relative to `xnnpack_path`.

    Returns:
        A ``collections.defaultdict(list)`` mapping each variable name to its
        source paths with the leading ``src/`` removed.
    """
    # Loop-invariant: build the set of recognized list names once.
    known_names = set(WRAPPER_SRC_NAMES) | set(SRC_NAMES)
    sources = collections.defaultdict(list)

    with open(os.path.join(xnnpack_path, cmakefile)) as cmake:
        lines = cmake.readlines()

    i = 0
    while i < len(lines):
        line = lines[i]
        if not line.startswith("SET"):
            i += 1
            continue

        if "src/" in line:
            # NOTE: a list whose first entry shares the SET( line but which
            # continues on later lines is only read up to the end of this line.
            name, val = handle_singleline_parse(line)
            sources[name].extend(val)
            i += 1
            continue

        # partition() cannot raise on a "SET..." line that has no "(".
        name = line.partition('(')[2].strip(' \t\n\r')
        i += 1
        if name not in known_names:
            continue

        while i < len(lines):
            is_closing_line = ')' in lines[i]
            entry = lines[i].strip(' \t\n\r)')
            # Blank lines, comment lines and a lone ")" carry no source path;
            # recording them would yield empty or garbage file names.
            if entry and not entry.startswith('#'):
                sources[name].append(_strip_src_prefix(entry))
            if is_closing_line:
                # Leave `i` on the closing line; the outer loop re-examines it
                # and moves on, exactly as before.
                break
            i += 1
    return sources
130
def _write_wrapper(filepath, filename, condition):
    """Write one wrapper file that #includes `filename`.

    When `condition` is not None the include is enclosed in
    ``#if <condition>`` / ``#endif``.
    """
    with open(filepath, "w") as wrapper:
        print("/* {} */".format(BANNER), file=wrapper)
        print(file=wrapper)

        # Architecture- or platform-dependent preprocessor flags can be
        # defined here. Note: platform_preprocessor_flags can't be used
        # because they are ignored by arc focus & buck project.

        if condition is None:
            print("#include <%s>" % filename, file=wrapper)
        else:
            # Include source file only if condition is satisfied
            print("#if %s" % condition, file=wrapper)
            print("#include <%s>" % filename, file=wrapper)
            print("#endif /* %s */" % condition, file=wrapper)


def _write_bzl_defs(bzl_name, names, sources, path_prefix):
    """Write `bzl_name` next to this script with one list per entry of `names`.

    Each list holds the entries of ``sources[name]`` prefixed by `path_prefix`.
    """
    bzl_path = os.path.join(os.path.dirname(__file__), bzl_name)
    with open(bzl_path, 'w') as defs:
        print('"""', file=defs)
        print(BANNER, file=defs)
        print('"""', file=defs)
        for name in names:
            print('\n' + name + ' = [', file=defs)
            for file_name in sources[name]:
                print('    "{}{}",'.format(path_prefix, file_name), file=defs)
            print(']', file=defs)


def gen_wrappers(xnnpack_path):
    """Generate the XNNPACK wrapper sources and the two .bzl definition files.

    Source lists are read from ``XNNPACK/CMakeLists.txt`` and
    ``XNNPACK/cmake/microkernels.cmake`` under `xnnpack_path`; lists found in
    the latter replace same-named lists from the former. For every file in a
    WRAPPER_SRC_NAMES list a wrapper is written under
    ``<xnnpack_path>/xnnpack_wrappers``. ``xnnpack_wrapper_defs.bzl`` and
    ``xnnpack_src_defs.bzl`` are written to the directory of this script.

    Args:
        xnnpack_path: Directory containing the ``XNNPACK`` checkout and
            receiving the ``xnnpack_wrappers`` folder.
    """
    sources = update_sources(xnnpack_path)
    sources.update(update_sources(xnnpack_path, "XNNPACK/cmake/microkernels.cmake"))

    # Group the wrapped files by the preprocessor condition guarding them.
    xnnpack_sources = collections.defaultdict(list)
    for name, condition in WRAPPER_SRC_NAMES.items():
        xnnpack_sources[condition].extend(sources[name])

    for condition, filenames in xnnpack_sources.items():
        print(condition)
        for filename in filenames:
            filepath = os.path.join(xnnpack_path, "xnnpack_wrappers", filename)
            # exist_ok avoids the check-then-create race of isdir()+makedirs().
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            _write_wrapper(filepath, filename, condition)

    # update xnnpack_wrapper_defs.bzl file under the same folder
    _write_bzl_defs("xnnpack_wrapper_defs.bzl", WRAPPER_SRC_NAMES, sources,
                    "xnnpack_wrappers/")

    # update xnnpack_src_defs.bzl file under the same folder
    # SRC_NAMES is a set: iterating it directly would order the generated
    # lists differently from run to run, so emit them sorted instead.
    _write_bzl_defs("xnnpack_src_defs.bzl", sorted(SRC_NAMES), sources,
                    "XNNPACK/src/")
186
187
def main(argv):
    """Run the generator on ``argv[0]``, or on "." when no argument is given.

    Args:
        argv: Command-line arguments without the program name; may be None
            or empty.
    """
    target_dir = argv[0] if argv else "."
    gen_wrappers(target_dir)
193
# The first argument is the directory that contains the "XNNPACK" checkout and
# in which the "xnnpack_wrappers" folder is created.
# Running the script without arguments uses the current directory instead.
# The two .bzl files are always written to the directory containing this
# script (os.path.dirname(__file__)), regardless of the argument.
if __name__ == "__main__":
    main(sys.argv[1:])
198    main(sys.argv[1:])
199