xref: /aosp_15_r20/external/XNNPACK/tools/generate-vunary-test.py (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import math
10import os
11import re
12import sys
13import yaml
14
15sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16import xngen
17import xnncommon
18
19
# Command-line interface: both the YAML specification and the generated C++
# file path are mandatory.
parser = argparse.ArgumentParser(
  description="Vector unary operation microkernel test generator",
)
parser.add_argument(
  "-s", "--spec",
  dest="spec",
  metavar="FILE",
  required=True,
  help="Specification (YAML) file",
)
parser.add_argument(
  "-o", "--output",
  dest="output",
  metavar="FILE",
  required=True,
  help="Output (C++ source) file",
)
# `defines` always exists on the parsed namespace as an empty list; main() in
# this script does not read it.
parser.set_defaults(defines=[])
27
28
def split_ukernel_name(name):
  """Splits a vector unary micro-kernel name into its test-relevant parts.

  Accepted names have the shape
  xnn_<dt>[_<dt>...]_v<op>_[fact_]ukernel__<target>_x<tile>, where <dt> is one
  of s8/u8/f16/f32/u32/u64 and <op> is one of the keys of `op_names` below.

  Args:
    name: C name of the micro-kernel.

  Returns:
    Tuple (op_type, batch_tile, arch, isa): the tester operation name mapped
    from <op>, the integer value of <tile>, and the arch/isa pair returned by
    xnncommon.parse_target_name() for <target>.

  Raises:
    ValueError: if `name` does not follow the pattern above.
  """
  # Short operation token in the kernel name -> tester operation name.
  op_names = {
    "abs": "Abs",
    "clamp": "Clamp",
    "elu": "ELU",
    "hswish": "HardSwish",
    "lrelu": "LeakyReLU",
    "neg": "Negate",
    "relu": "ReLU",
    "rndd": "RoundDown",
    "rndne": "RoundToNearestEven",
    "rndz": "RoundTowardsZero",
    "rndu": "RoundUp",
    "sigmoid": "Sigmoid",
    "sqr": "Square",
    "sqrt": "SquareRoot",
    "sqrtshift": "SquareRootShift",
  }
  datatype = "(?:s8|u8|f16|f32|u32|u64)"
  # The op alternation is derived from the table so the two cannot drift apart.
  # Op tokens contain no "_", so alternation order does not affect matching.
  pattern = "".join((
    "xnn_", datatype, "(?:_", datatype, ")*",
    "_v(?P<op>", "|".join(op_names), ")",
    "_(?:fact_)?ukernel__(?P<target>.+)",
    r"_x(?P<tile>\d+)",
  ))

  found = re.fullmatch(pattern, name)
  if not found:
    raise ValueError("Unexpected microkernel name: " + name)

  arch, isa = xnncommon.parse_target_name(target_name=found.group("target"))
  return op_names[found.group("op")], int(found.group("tile")), arch, isa
54
55
# xngen template expanded once per micro-kernel by generate_test_cases().
#
# Template variables:
#   TEST_NAME   gtest suite name derived from the micro-kernel name.
#   BATCH_TILE  number of batch elements processed per inner-loop iteration.
#   ISA_CHECK   ISA-check statement placed at the top of every test; emitted
#               only when ISA_CHECK is truthy ($if ISA_CHECK).
#   TEST_ARGS   argument strings joined with ", " into the
#               VUnaryMicrokernelTester().Test(...) call.
#   OP_TYPE     operation name (e.g. "Clamp", "ELU") that selects the
#               op-specific tests below.
#
# Generated tests:
#   batch_eq_<tile> and batch_gt_<tile> always; batch_div_<tile> and
#   batch_lt_<tile> only when BATCH_TILE > 1; inplace for every op except
#   SquareRootShift. Parameter sweeps: qmin/qmax (Clamp), prescale/alpha/beta
#   (ELU), slope (LeakyReLU), shift (SquareRootShift).
#
# NOTE(review): despite the "BINOP" prefix, this template generates tests for
# unary (VUnary) micro-kernels. The name is kept because generate_test_cases()
# refers to it.
BINOP_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, batch_eq_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  VUnaryMicrokernelTester()
    .batch_size(${BATCH_TILE})
    .Test(${", ".join(TEST_ARGS)});
}

$if BATCH_TILE > 1:
  TEST(${TEST_NAME}, batch_div_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = ${BATCH_TILE*2}; batch_size < ${BATCH_TILE*10}; batch_size += ${BATCH_TILE}) {
      VUnaryMicrokernelTester()
        .batch_size(batch_size)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, batch_lt_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size < ${BATCH_TILE}; batch_size++) {
      VUnaryMicrokernelTester()
        .batch_size(batch_size)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch_size = ${BATCH_TILE+1}; batch_size < ${10 if BATCH_TILE == 1 else BATCH_TILE*2}; batch_size++) {
    VUnaryMicrokernelTester()
      .batch_size(batch_size)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if OP_TYPE != "SquareRootShift":
  TEST(${TEST_NAME}, inplace) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
      VUnaryMicrokernelTester()
        .batch_size(batch_size)
        .inplace(true)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

$if OP_TYPE == "Clamp":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint8_t qmin = 1; qmin < 255; qmin++) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .qmin(qmin)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint8_t qmax = 1; qmax < 255; qmax++) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .qmax(qmax)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OP_TYPE == "ELU":
  TEST(${TEST_NAME}, prescale) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float prescale : std::vector<float>({0.1f, 10.0f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .prescale(prescale)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, alpha) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float alpha : std::vector<float>({0.3f, 3.0f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .alpha(alpha)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, beta) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float beta : std::vector<float>({0.3f, 3.0f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .beta(beta)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OP_TYPE == "LeakyReLU":
  TEST(${TEST_NAME}, slope) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float slope : std::vector<float>({-0.7f, 0.3f, 1.3f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .slope(slope)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OP_TYPE == "SquareRootShift":
  TEST(${TEST_NAME}, shift) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t shift = 0; shift < 32; shift++) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .shift(shift)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
"""
203
204
def generate_test_cases(ukernel, op_type, init_fn, batch_tile, isa):
  """Generates all tests cases for a Vector Unary Operation micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    op_type: Operation type, as returned by split_ukernel_name().
    init_fn: C name of the function to initialize microkernel parameters, or
             None (or an empty string) if the micro-kernel has no initializer.
    batch_tile: Number of batch elements processed per one iteration of the
                inner loop of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
         will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  # Drop the leading "xnn_" to form the test name. The second "_"-separated
  # component is the datatype token.
  _, test_name = ukernel.split("_", 1)
  _, datatype, _ = ukernel.split("_", 2)

  test_args = [ukernel]
  # Rounding kernels also pass the tester OpType value to Test().
  if op_type.startswith("Round"):
    test_args.append("VUnaryMicrokernelTester::OpType::" + op_type)
  # Treat an empty "init" entry like a missing one. Appending "" would emit a
  # dangling ", " in the generated Test(...) call.
  if init_fn:
    test_args.append(init_fn)

  return xngen.preprocess(BINOP_TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "DATATYPE": datatype,
      "BATCH_TILE": batch_tile,
      "OP_TYPE": op_type,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
    })
236
237
def main(args):
  """Generates the C++ unit-test file for every micro-kernel in a YAML spec.

  The output file is rewritten only when the generated text differs from its
  current contents.

  Args:
    args: command-line arguments without the program name (parsed by `parser`).

  Raises:
    ValueError: if the spec is not a YAML list, or if a micro-kernel name does
      not match the expected naming pattern.
    KeyError: if a spec entry has no "name" field.
  """
  options = parser.parse_args(args)

  # Parse the spec and release the file before any generation work.
  # The built-in open() is used instead of the discouraged codecs.open().
  with open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/vunary.h>
#include "vunary-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # Collect the pieces and join once, rather than growing a string in the loop.
  pieces = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    init_fn = ukernel_spec.get("init")
    op_type, batch_tile, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, op_type, init_fn, batch_tile, isa)
    pieces.append(xnncommon.postprocess_test_case(test_case, arch, isa))
  tests = "\n\n".join(pieces)

  # Only rewrite the output when the content changes (presumably so unchanged
  # files keep their timestamps and do not trigger rebuilds). newline=""
  # disables newline translation, so both the comparison and the written file
  # use the generated "\n" line endings exactly.
  try:
    with open(options.output, "r", encoding="utf-8", newline="") as output_file:
      txt_changed = output_file.read() != tests
  except FileNotFoundError:
    txt_changed = True

  if txt_changed:
    with open(options.output, "w", encoding="utf-8", newline="") as output_file:
      output_file.write(tests)
285
286
287if __name__ == "__main__":
288  main(sys.argv[1:])
289