xref: /aosp_15_r20/external/XNNPACK/tools/generate-transpose-test.py (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1#!/usr/bin/env python
2# Copyright 2021 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import math
10import os
11import re
12import sys
13import yaml
14
15sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16import xngen
17import xnncommon
18
# Command-line interface: both the YAML spec to read and the C++ file to
# (re)generate are mandatory.
parser = argparse.ArgumentParser(
    description="Matrix transpose microkernel test generator")
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help="Output (C++ source) file")
parser.set_defaults(defines=[])
34
35
def split_ukernel_name(name):
  """Parses a transpose micro-kernel name into its test parameters.

  Accepted names look like
  ``xnn_x<bits>_transpose<v|c>_ukernel__<tile_height>x<tile_width>_<target>``,
  where ``<bits>`` is either a decimal element width in bits or the literal
  ``x`` (an ``xx`` kernel, whose name implies no fixed element size). The
  ``v``/``c`` variant letter is matched but not returned.

  Args:
    name: C name of the micro-kernel.

  Returns:
    A ``(tile_height, tile_width, element_size, arch, isa)`` tuple, where
    ``element_size`` is in bytes or ``None`` for ``xx`` kernels, and
    ``arch``/``isa`` come from ``xnncommon.parse_target_name``.

  Raises:
    ValueError: if the name does not match the pattern above, or if the
      element width is not a positive whole number of bytes.
  """
  # Restrict the element field to digits or "x" so malformed names are
  # reported with a clear message instead of a generic int() failure.
  match = re.fullmatch(
      r"xnn_x(\d+|x)_transpose(v|c)_ukernel__(\d+)x(\d+)_(.+)", name)
  if match is None:
    raise ValueError("Unexpected microkernel name: " + name)

  if match.group(1) == "x":
    element_size = None
  else:
    element_bits = int(match.group(1))
    # A width such as 4 or 12 bits would otherwise truncate to a bogus
    # byte count (0 or 1) and silently generate wrong tests.
    if element_bits == 0 or element_bits % 8 != 0:
      raise ValueError(
          "Element width of %d bits is not a whole number of bytes in "
          "microkernel name: %s" % (element_bits, name))
    element_size = element_bits // 8
  tile_height = int(match.group(3))
  tile_width = int(match.group(4))

  arch, isa = xnncommon.parse_target_name(target_name=match.group(5))
  return tile_height, tile_width, element_size, arch, isa
49
50
# gtest template for one transpose micro-kernel at one element size.
# generate_test_cases() passes it to xngen.preprocess() with the keys
# TEST_NAME, KERNEL, TILE_HEIGHT, TILE_WIDTH, ELEMENT_SIZE and ISA_CHECK.
# The ${...} expressions and $if lines are xngen template syntax, presumably
# evaluated against those keys.
#
# Test-name abbreviations, as used by the tester calls below:
#   bh / bw   = block_height / block_width
#   is / os   = input_stride / output_stride
#   ies / oes = input_element_stride / output_element_stride
# Ranges like bw_${TILE_WIDTH + 1}_${TILE_WIDTH * 2} loop with an exclusive
# upper bound (`<`), so they run no iterations when the tile dimension is 1.
# The bh_1_..._bw_1_... test instead loops with inclusive bounds (`<=`).
#
# NOTE(review): the bh_${TILE_HEIGHT + 1}_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH}
# test actually uses block_width(${TILE_WIDTH + 3}), so its name understates
# the block width. Renaming it would change generated gtest names, and it
# would collide with the ..._bw_${TILE_WIDTH * 2} test when TILE_WIDTH == 3,
# so it is left unchanged.
TRANSPOSE_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH * 2})
    .output_stride(${TILE_HEIGHT * 2})
    .block_width(${TILE_WIDTH})
    .block_height(${TILE_HEIGHT})
    .element_size(${ELEMENT_SIZE})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_1_${TILE_HEIGHT * 2}_bw_1_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = 1; i <= ${TILE_HEIGHT * 2}; ++i){
    for(size_t j = 1; j <= ${TILE_WIDTH * 2}; ++j){
      TransposeMicrokernelTester()
        .input_stride(j * 3)
        .output_stride(i * 7)
        .block_width(j)
        .block_height(i)
        .element_size(${ELEMENT_SIZE})
        .iterations(1)
        .Test(${KERNEL});
    }
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH * 2})
    .output_stride(${TILE_HEIGHT})
    .block_width(${TILE_WIDTH * 2})
    .block_height(${TILE_HEIGHT})
    .element_size(${ELEMENT_SIZE})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH + 1}_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = ${TILE_WIDTH + 1}; i < ${TILE_WIDTH * 2}; ++i){
    TransposeMicrokernelTester()
      .input_stride(i)
      .output_stride(${TILE_HEIGHT * 2})
      .block_width(i)
      .block_height(${TILE_HEIGHT})
      .element_size(${ELEMENT_SIZE})
      .iterations(1)
      .Test(${KERNEL});
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH + 1}_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = ${TILE_WIDTH + 1}; i < ${TILE_WIDTH * 2}; ++i){
    TransposeMicrokernelTester()
      .input_stride(i)
      .output_stride(${TILE_HEIGHT * 2})
      .block_width(i)
      .block_height(${TILE_HEIGHT * 2})
      .element_size(${ELEMENT_SIZE})
      .iterations(1)
      .Test(${KERNEL});
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH})
    .output_stride(${TILE_HEIGHT * 3 + 4})
    .block_width(${TILE_WIDTH})
    .block_height(${TILE_HEIGHT * 2})
    .element_size(${ELEMENT_SIZE})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT + 1}_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH}){
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = ${TILE_HEIGHT + 1}; i < ${TILE_HEIGHT * 2}; ++i){
    TransposeMicrokernelTester()
      .input_stride(${TILE_WIDTH + 17})
      .output_stride(i)
      .block_width(${TILE_WIDTH + 3})
      .block_height(i)
      .element_size(${ELEMENT_SIZE})
      .iterations(1)
      .Test(${KERNEL});
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT + 1}_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH * 2}){
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = ${TILE_HEIGHT + 1}; i < ${TILE_HEIGHT * 2}; ++i){
    TransposeMicrokernelTester()
      .input_stride(${TILE_WIDTH * 2})
      .output_stride(i)
      .block_width(${TILE_WIDTH * 2})
      .block_height(i)
      .element_size(${ELEMENT_SIZE})
      .iterations(1)
      .Test(${KERNEL});
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT + 1}_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH + 1}_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = ${TILE_HEIGHT + 1}; i < ${TILE_HEIGHT * 2}; ++i){
    for(size_t j = ${TILE_WIDTH + 1}; j < ${TILE_WIDTH * 2}; ++j){
      TransposeMicrokernelTester()
        .input_stride(j)
        .output_stride(i)
        .block_width(j)
        .block_height(i)
        .element_size(${ELEMENT_SIZE})
        .iterations(1)
        .Test(${KERNEL});
    }
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH}_is_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH * 2})
    .output_stride(${TILE_HEIGHT})
    .block_width(${TILE_WIDTH})
    .block_height(${TILE_HEIGHT})
    .element_size(${ELEMENT_SIZE})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH}_os_${TILE_HEIGHT * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH})
    .output_stride(${TILE_HEIGHT * 2})
    .block_width(${TILE_WIDTH})
    .block_height(${TILE_HEIGHT})
    .element_size(${ELEMENT_SIZE})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH}_is_${TILE_WIDTH * 2}_os_${TILE_HEIGHT * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH * 2})
    .output_stride(${TILE_HEIGHT * 2})
    .block_width(${TILE_WIDTH})
    .block_height(${TILE_HEIGHT})
    .element_size(${ELEMENT_SIZE})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT * 17}_bw_${TILE_WIDTH * 19}_ies_${ELEMENT_SIZE + 11}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH * 19})
    .output_stride(${TILE_HEIGHT * 17})
    .block_width(${TILE_WIDTH * 19})
    .block_height(${TILE_HEIGHT * 17})
    .element_size(${ELEMENT_SIZE})
    .input_element_stride(${ELEMENT_SIZE + 11})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT * 3}_bw_${TILE_WIDTH * 5}_oes_${ELEMENT_SIZE + 11}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH * 5})
    .output_stride(${TILE_HEIGHT * 3})
    .block_width(${TILE_WIDTH * 5})
    .block_height(${TILE_HEIGHT * 3})
    .element_size(${ELEMENT_SIZE})
    .output_element_stride(${ELEMENT_SIZE + 11})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT * 7}_bw_${TILE_WIDTH * 23}_ies_${ELEMENT_SIZE + 17}_oes_${ELEMENT_SIZE + 13}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH * 23 + 5})
    .output_stride(${TILE_HEIGHT * 7 + 6})
    .block_width(${TILE_WIDTH * 23})
    .block_height(${TILE_HEIGHT * 7})
    .element_size(${ELEMENT_SIZE})
    .input_element_stride(${ELEMENT_SIZE + 17})
    .output_element_stride(${ELEMENT_SIZE + 13})
    .iterations(1)
    .Test(${KERNEL});
}
"""
267
268
def generate_test_cases(ukernel, tile_height, tile_width, element_size, isa):
  """Generates all test cases for a matrix transpose micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    tile_height: Number of vertical elements processed by the ukernel.
    tile_width: Number of horizontal elements processed by the ukernel.
    element_size: Size of each element in bytes. It is also appended to the
      test-suite name, so one kernel can be instantiated for several element
      sizes (as main() does for "xx" kernels) without test-name collisions.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test cases.
  """
  # Drop the "xnn" prefix and the "ukernel_" infix, e.g.
  # "xnn_x32_transposec_ukernel__4x4_sse" -> "X32_TRANSPOSEC__4X4_SSE_4".
  _, test_name = ukernel.split("_", 1)
  return xngen.preprocess(
      TRANSPOSE_TEST_TEMPLATE, {
          "TEST_NAME": test_name.upper().replace("UKERNEL_", "") + "_" + str(element_size),
          "KERNEL": ukernel,
          "TILE_HEIGHT": tile_height,
          "TILE_WIDTH": tile_width,
          "ELEMENT_SIZE": element_size,
          "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      })
294
def generate_memcpy_test_cases(ukernel, tile_height, isa):
  """Generates test cases for a memcpy-based transpose micro-kernel.

  NOTE(review): TRANSPOSE_MEMCPY_TEST_TEMPLATE is not defined in this file,
  and main() never calls this function (it tests "xx" kernels through
  generate_test_cases instead). Calling it as-is therefore raises NameError.
  Either define the template or remove this function.

  Args:
    ukernel: C name of the micro-kernel function.
    tile_height: Number of vertical elements processed by the ukernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test cases.
  """
  # Drop the "xnn" prefix and the "ukernel_" infix from the C name.
  _, test_name = ukernel.split("_", 1)
  return xngen.preprocess(
      TRANSPOSE_MEMCPY_TEST_TEMPLATE, {
          "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
          "KERNEL": ukernel,
          "TILE_HEIGHT": tile_height,
          "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      })
316
317
# Element sizes (in bytes) exercised for "xx" kernels, whose name does not
# imply a fixed element size.
_XX_ELEMENT_SIZES = (1, 3, 5)


def _write_if_changed(path, text):
  """Writes `text` to `path` unless the file already contains exactly `text`.

  Leaving an up-to-date file untouched keeps its modification time, presumably
  to avoid needless rebuilds of the generated test.
  """
  if os.path.exists(path):
    with codecs.open(path, "r", encoding="utf-8") as existing_file:
      if existing_file.read() == text:
        return
  # codecs.open is kept (rather than the builtin open) so the written bytes
  # match the original generator exactly.
  with codecs.open(path, "w", encoding="utf-8") as output_file:
    output_file.write(text)


def main(args):
  """Generates the C++ transpose micro-kernel tests described by a YAML spec.

  Args:
    args: Command-line arguments without the program name; parsed with the
      module-level `parser` (requires --spec and --output).

  Raises:
    ValueError: if the spec is not a YAML list, or if split_ukernel_name
      rejects a micro-kernel name in it.
  """
  options = parser.parse_args(args)

  # Only the YAML parse needs the spec file, so close it before generating.
  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  # Accumulate pieces and join once instead of repeated string concatenation.
  parts = ["""\
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/transpose.h>
#include "transpose-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])]

  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    tile_height, tile_width, element_size, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    # "xx" kernels (element_size is None) are tested at several sizes; each
    # size yields a distinct test-suite name (see generate_test_cases).
    if element_size is None:
      element_sizes = _XX_ELEMENT_SIZES
    else:
      element_sizes = (element_size,)
    test_case = "".join(
        generate_test_cases(name, tile_height, tile_width, size, isa)
        for size in element_sizes)
    parts.append(
        "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa))

  _write_if_changed(options.output, "".join(parts))
370
371
if __name__ == "__main__":
  # argv[0] is the script path; main() hands the remaining arguments to argparse.
  main(args=sys.argv[1:])
374