xref: /aosp_15_r20/external/XNNPACK/tools/generate-window-test.py (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1#!/usr/bin/env python
2# Copyright 2022 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import math
10import os
11import re
12import sys
13import yaml
14
15sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16from primes import next_prime
17import xngen
18import xnncommon
19
20
21parser = argparse.ArgumentParser(description='Window microkernel test generator')
22parser.add_argument("-s", "--spec", metavar="FILE", required=True,
23                    help="Specification (YAML) file")
24parser.add_argument("-o", "--output", metavar="FILE", required=True,
25                    help='Output (C++ source) file')
26parser.set_defaults(defines=list())
27
28
29def split_ukernel_name(name):
30  shift = 0
31  row_tile = 1
32  match = re.fullmatch(r"xnn_s16_window(_shift(\d+))?_ukernel__(.+)_x(\d+)", name)
33  assert match is not None
34  if match.group(2):
35    shift = int(match.group(2))
36  batch_tile = int(match.group(4))
37
38  arch, isa = xnncommon.parse_target_name(target_name=match.group(3))
39  return shift, row_tile, batch_tile, arch, isa
40
41
42WINDOW_TEST_TEMPLATE = """\
43TEST(${TEST_NAME}, batch_eq_${BATCH_TILE}) {
44  $if ISA_CHECK:
45    ${ISA_CHECK};
46  WindowMicrokernelTester()
47    .rows(1)
48    .batch(${BATCH_TILE})
49    .shift(${SHIFT})
50    .Test(${", ".join(TEST_ARGS)});
51}
52
53$if BATCH_TILE > 1:
54  TEST(${TEST_NAME}, batch_div_${BATCH_TILE}) {
55    $if ISA_CHECK:
56      ${ISA_CHECK};
57    for (size_t batch = ${BATCH_TILE*2}; batch < ${BATCH_TILE*10}; batch += ${BATCH_TILE}) {
58      WindowMicrokernelTester()
59        .batch(batch)
60        .shift(${SHIFT})
61        .Test(${", ".join(TEST_ARGS)});
62    }
63  }
64
65  TEST(${TEST_NAME}, batch_lt_${BATCH_TILE}) {
66    $if ISA_CHECK:
67      ${ISA_CHECK};
68    for (size_t batch = 1; batch < ${BATCH_TILE}; batch++) {
69      WindowMicrokernelTester()
70        .batch(batch)
71        .shift(${SHIFT})
72        .Test(${", ".join(TEST_ARGS)});
73    }
74  }
75
76TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
77  $if ISA_CHECK:
78    ${ISA_CHECK};
79  for (size_t batch = ${BATCH_TILE+1}; batch < ${10 if BATCH_TILE == 1 else BATCH_TILE*2}; batch++) {
80    WindowMicrokernelTester()
81      .batch(batch)
82      .shift(${SHIFT})
83      .Test(${", ".join(TEST_ARGS)});
84  }
85}
86
87$if ROW_TILE > 1:
88  TEST(${TEST_NAME}, rows_lt_${ROW_TILE}) {
89    $if ISA_CHECK:
90      ${ISA_CHECK};
91    for (size_t rows = 1; rows < ${ROW_TILE}; rows++) {
92      for (size_t batch = 1; batch <= ${BATCH_TILE*5}; batch += ${max(1, BATCH_TILE-1)}) {
93        WindowMicrokernelTester()
94          .rows(rows)
95          .batch(batch)
96          .shift(${SHIFT})
97          .Test(${", ".join(TEST_ARGS)});
98      }
99    }
100  }
101
102  TEST(${TEST_NAME}, rows_div_${ROW_TILE}) {
103    $if ISA_CHECK:
104      ${ISA_CHECK};
105    for (size_t rows = ${ROW_TILE*2}; rows <= ${ROW_TILE*4}; rows += ${ROW_TILE}) {
106      for (size_t batch = 1; batch <= ${BATCH_TILE*5}; batch += ${max(1, BATCH_TILE-1)}) {
107        WindowMicrokernelTester()
108          .rows(rows)
109          .batch(batch)
110          .shift(${SHIFT})
111          .Test(${", ".join(TEST_ARGS)});
112      }
113    }
114  }
115
116TEST(${TEST_NAME}, rows_gt_${ROW_TILE}) {
117  $if ISA_CHECK:
118    ${ISA_CHECK};
119  for (size_t rows = ${ROW_TILE+1}; rows < ${ROW_TILE*2}; rows++) {
120    for (size_t batch = 1; batch <= ${BATCH_TILE*5}; batch += ${max(1, BATCH_TILE-1)}) {
121      WindowMicrokernelTester()
122        .rows(rows)
123        .batch(batch)
124        .shift(${SHIFT})
125        .Test(${", ".join(TEST_ARGS)});
126    }
127  }
128}
129
130TEST(${TEST_NAME}, inplace) {
131  $if ISA_CHECK:
132    ${ISA_CHECK};
133  for (size_t rows = 1; rows <= ${ROW_TILE*3}; rows += ${max(1, ROW_TILE-1)}) {
134    for (size_t batch = 1; batch <= ${BATCH_TILE*5}; batch += ${max(1, BATCH_TILE-1)}) {
135      WindowMicrokernelTester()
136        .rows(rows)
137        .batch(batch)
138        .shift(${SHIFT})
139        .inplace(true)
140        .iterations(1)
141        .Test(${", ".join(TEST_ARGS)});
142    }
143  }
144}
145
146$if SHIFT == 0:
147  TEST(${TEST_NAME}, shift) {
148    $if ISA_CHECK:
149      ${ISA_CHECK};
150    for (uint32_t shift = 0; shift < 32; shift++) {
151      WindowMicrokernelTester()
152        .rows(${ROW_TILE})
153        .batch(${BATCH_TILE})
154        .shift(shift)
155        .Test(${", ".join(TEST_ARGS)});
156    }
157  }
158
159"""
160
161
162def generate_test_cases(ukernel, shift, row_tile, batch_tile, isa):
163  """Generates all tests cases for a Window micro-kernel.
164
165  Args:
166    ukernel: C name of the micro-kernel function.
167    shift: Shift by constant value.
168    row_tile: Number of rows (pixels) processed per one iteration of the outer
169              loop of the micro-kernel.
170    batch_tile: Number of batch processed per one iteration of the inner
171                  loop of the micro-kernel.
172    isa: instruction set required to run the micro-kernel. Generated unit test
173         will skip execution if the host processor doesn't support this ISA.
174
175  Returns:
176    Code for the test case.
177  """
178  _, test_name = ukernel.split("_", 1)
179  _, datatype, ukernel_type, _ = ukernel.split("_", 3)
180  return xngen.preprocess(WINDOW_TEST_TEMPLATE, {
181      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
182      "TEST_ARGS": [ukernel],
183      "DATATYPE": datatype,
184      "SHIFT": shift,
185      "ROW_TILE": row_tile,
186      "BATCH_TILE": batch_tile,
187      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
188      "next_prime": next_prime,
189    })
190
191
192def main(args):
193  options = parser.parse_args(args)
194
195  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
196    spec_yaml = yaml.safe_load(spec_file)
197    if not isinstance(spec_yaml, list):
198      raise ValueError("expected a list of micro-kernels in the spec")
199
200    tests = """\
201// Copyright 2022 Google LLC
202//
203// This source code is licensed under the BSD-style license found in the
204// LICENSE file in the root directory of this source tree.
205//
206// Auto-generated file. Do not edit!
207//   Specification: {specification}
208//   Generator: {generator}
209
210
211#include <gtest/gtest.h>
212
213#include <xnnpack/common.h>
214#include <xnnpack/isa-checks.h>
215
216#include <xnnpack/window.h>
217#include "window-microkernel-tester.h"
218""".format(specification=options.spec, generator=sys.argv[0])
219
220    for ukernel_spec in spec_yaml:
221      name = ukernel_spec["name"]
222      shift, row_tile, batch_tile, arch, isa = split_ukernel_name(name)
223
224      # specification can override architecture
225      arch = ukernel_spec.get("arch", arch)
226
227      test_case = generate_test_cases(name, shift, row_tile, batch_tile, isa)
228      tests += "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa)
229
230    txt_changed = True
231    if os.path.exists(options.output):
232      with codecs.open(options.output, "r", encoding="utf-8") as output_file:
233        txt_changed = output_file.read() != tests
234
235    if txt_changed:
236      with codecs.open(options.output, "w", encoding="utf-8") as output_file:
237        output_file.write(tests)
238
239
240if __name__ == "__main__":
241  main(sys.argv[1:])
242