xref: /aosp_15_r20/external/XNNPACK/tools/generate-pack-test.py (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import argparse
8import bisect
9import codecs
10import os
11import sys
12import yaml
13
14sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15from primes import next_prime
16import xngen
17import xnncommon
18
19
# Command-line interface: both the input spec and the output path are
# mandatory.
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s", "--spec",
    metavar="FILE",
    required=True,
    help="Spec (YAML) file",
)
parser.add_argument(
    "-o", "--output",
    metavar="FILE",
    required=True,
    help="Output (C++ source) file",
)
# No option populates `defines`; it is always an empty list here.
parser.set_defaults(defines=[])
26
27
def split_ukernel_name(name):
  """Splits a PACK micro-kernel name into its MR value and target.

  The part of `name` before the first "__" is the common name. The last
  "_"-separated field of the common name is inspected, and the text before
  its first "x" is parsed as MR. The part after the first "__" is handed to
  xnncommon.parse_target_name().

  Args:
    name: micro-kernel name containing a "__" separator.

  Returns:
    A (mr, arch, isa) tuple.

  Raises:
    ValueError: if `name` has no "__" or the MR text is not an integer.
  """
  prefix, target = name.split("__", 1)
  tile_field = prefix.split("_")[-1]
  mr_text = tile_field.split("x")[0]
  arch, isa = xnncommon.parse_target_name(target)
  return int(mr_text), arch, isa
35
36
# xngen template for the generated PACK micro-kernel unit tests, expanded by
# generate_test_cases(). `$if` lines and `${...}` expressions are processed
# by xngen; the names referenced here (TEST_NAME, MR, KBLOCK, ISA_CHECK,
# UKERNEL_NAME, next_prime) are supplied by generate_test_cases().
#
# Coverage: k == KBLOCK, k < KBLOCK (only when KBLOCK != 1), k > KBLOCK,
# k a multiple of KBLOCK (only when KBLOCK > 1), and a strided-input test
# whose x_stride is a prime larger than the largest k exercised. Most cases
# come in a full-tile (m == MR) and a subtile (m in 1..MR) flavor.
PACK_TEST_CODE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  PackMicrokernelTester()
    .mr(${MR})
    .m(${MR})
    .k(${KBLOCK})
    .Test(${UKERNEL_NAME});
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t m = 1; m <= ${MR}; m++) {
    PackMicrokernelTester()
      .mr(${MR})
      .m(m)
      .k(${KBLOCK})
      .Test(${UKERNEL_NAME});
  }
}

$if KBLOCK != 1:
  TEST(${TEST_NAME}, k_lt_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${KBLOCK}; k++) {
      PackMicrokernelTester()
        .mr(${MR})
        .m(${MR})
        .k(k)
        .Test(${UKERNEL_NAME});
    }
  }

  TEST(${TEST_NAME}, k_lt_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${KBLOCK}; k++) {
      for (size_t m = 1; m <= ${MR}; m++) {
        PackMicrokernelTester()
          .mr(${MR})
          .m(m)
          .k(k)
          .Test(${UKERNEL_NAME});
      }
    }
  }

TEST(${TEST_NAME}, k_gt_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${KBLOCK + 1}; k < ${10 if KBLOCK == 1 else KBLOCK * 2}; k++) {
    PackMicrokernelTester()
      .mr(${MR})
      .m(${MR})
      .k(k)
      .Test(${UKERNEL_NAME});
  }
}

TEST(${TEST_NAME}, k_gt_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${KBLOCK + 1}; k < ${10 if KBLOCK == 1 else KBLOCK * 2}; k++) {
    for (size_t m = 1; m <= ${MR}; m++) {
      PackMicrokernelTester()
        .mr(${MR})
        .m(m)
        .k(k)
        .Test(${UKERNEL_NAME});
    }
  }
}

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${KBLOCK * 2}; k < ${KBLOCK * 10}; k += ${KBLOCK}) {
      PackMicrokernelTester()
        .mr(${MR})
        .m(${MR})
        .k(k)
        .Test(${UKERNEL_NAME});
    }
  }

  TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${KBLOCK * 2}; k < ${KBLOCK * 10}; k += ${KBLOCK}) {
      for (size_t m = 1; m <= ${MR}; m++) {
        PackMicrokernelTester()
          .mr(${MR})
          .m(m)
          .k(k)
          .Test(${UKERNEL_NAME});
      }
    }
  }

TEST(${TEST_NAME}, strided_x) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
    PackMicrokernelTester()
      .mr(${MR})
      .m(${MR})
      .k(k)
      .x_stride(${next_prime(KBLOCK * 5 + 1)})
      .Test(${UKERNEL_NAME});
  }
}
"""
153
154
def generate_test_cases(ukernel, mr, k_block, isa):
  """Generates all test cases for a PACK micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function. It must contain at least
             three "_" separators; the second and third "_"-separated fields
             are used as the datatype and micro-kernel type.
    mr: MR parameter of the PACK micro-kernel.
    k_block: Number of K values processed per one iteration of the main loop of
             the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
         will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test cases, produced by expanding PACK_TEST_CODE.

  Raises:
    ValueError: if `ukernel` has fewer than three "_" separators.
  """
  # Test name: drop the leading "xnn"-style prefix field, upper-case the rest
  # and strip the "UKERNEL_" marker.
  _, test_name = ukernel.split("_", 1)
  _, datatype, ukernel_type, _ = ukernel.split("_", 3)
  return xngen.preprocess(PACK_TEST_CODE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      # UKERNEL_TYPE and DATATYPE are not referenced by the current template
      # but are kept available to it.
      "UKERNEL_TYPE": ukernel_type.upper(),
      "UKERNEL_NAME": ukernel,
      "DATATYPE": datatype,
      "MR": mr,
      "KBLOCK": k_block,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  })
182
183
def main(args):
  """Generates the PACK micro-kernel unit-test source from a YAML spec.

  The output file is only rewritten when its contents would change.

  Args:
    args: command-line arguments, excluding the program name.

  Raises:
    ValueError: if the spec file does not contain a YAML list.
  """
  options = parser.parse_args(args)

  # Only the YAML load needs the spec file; close it before generating code.
  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/packx.h>
#include "pack-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # Collect the pieces and join once; the header and every test case are
  # separated by a blank line ("\n\n").
  parts = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    k_block = int(ukernel_spec["k-block"])
    mr, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, mr, k_block, isa)
    parts.append(xnncommon.postprocess_test_case(test_case, arch, isa))
  tests = "\n\n".join(parts)

  # Leave an up-to-date output file untouched (presumably to avoid
  # needless rebuilds of dependents).
  txt_changed = True
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      txt_changed = output_file.read() != tests

  if txt_changed:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])
235