#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import math
import os
import re
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import xngen
import xnncommon


parser = argparse.ArgumentParser(
  description='Vector unary operation microkernel test generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=list())


def split_ukernel_name(name):
  """Parses a vector unary micro-kernel name into its test parameters.

  Args:
    name: C name of the micro-kernel, matching
      xnn_<type>[_<type>...]_v<op>_[fact_]ukernel__<target>_x<tile>.

  Returns:
    Tuple (op_type, batch_tile, arch, isa): op_type is the tester operation
    name (e.g. "Abs"), batch_tile is the integer after the trailing "_x", and
    arch/isa are whatever xnncommon.parse_target_name returns for <target>.

  Raises:
    ValueError: if the name does not match the expected pattern.
  """
  match = re.fullmatch(r"xnn_(s8|u8|f16|f32|u32|u64)(_(s8|u8|f16|f32|u32|u64))*_v(abs|clamp|elu|hswish|lrelu|neg|relu|rndd|rndne|rndu|rndz|sigmoid|sqr|sqrt|sqrtshift)_(fact_)?ukernel__(.+)_x(\d+)", name)
  if match is None:
    raise ValueError("Unexpected microkernel name: " + name)
  # Capture groups: 4 = operation suffix, 6 = target name, 7 = batch tile.
  op_type = {
    "abs": "Abs",
    "clamp": "Clamp",
    "elu": "ELU",
    "hswish": "HardSwish",
    "lrelu": "LeakyReLU",
    "neg": "Negate",
    "relu": "ReLU",
    "rndd": "RoundDown",
    "rndne": "RoundToNearestEven",
    "rndz": "RoundTowardsZero",
    "rndu": "RoundUp",
    "sigmoid": "Sigmoid",
    "sqr": "Square",
    "sqrt": "SquareRoot",
    "sqrtshift": "SquareRootShift",
  }[match.group(4)]
  batch_tile = int(match.group(7))

  arch, isa = xnncommon.parse_target_name(target_name=match.group(6))
  return op_type, batch_tile, arch, isa


# xngen template for the gtest cases of one vector *unary* micro-kernel (the
# "BINOP" name is historical). Lines under a `$if` are indented two extra
# spaces; xngen strips that indentation when the condition holds.
BINOP_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, batch_eq_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  VUnaryMicrokernelTester()
    .batch_size(${BATCH_TILE})
    .Test(${", ".join(TEST_ARGS)});
}

$if BATCH_TILE > 1:
  TEST(${TEST_NAME}, batch_div_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = ${BATCH_TILE*2}; batch_size < ${BATCH_TILE*10}; batch_size += ${BATCH_TILE}) {
      VUnaryMicrokernelTester()
        .batch_size(batch_size)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, batch_lt_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size < ${BATCH_TILE}; batch_size++) {
      VUnaryMicrokernelTester()
        .batch_size(batch_size)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch_size = ${BATCH_TILE+1}; batch_size < ${10 if BATCH_TILE == 1 else BATCH_TILE*2}; batch_size++) {
    VUnaryMicrokernelTester()
      .batch_size(batch_size)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if OP_TYPE != "SquareRootShift":
  TEST(${TEST_NAME}, inplace) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
      VUnaryMicrokernelTester()
        .batch_size(batch_size)
        .inplace(true)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

$if OP_TYPE == "Clamp":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint8_t qmin = 1; qmin < 255; qmin++) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .qmin(qmin)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint8_t qmax = 1; qmax < 255; qmax++) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .qmax(qmax)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OP_TYPE == "ELU":
  TEST(${TEST_NAME}, prescale) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float prescale : std::vector<float>({0.1f, 10.0f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .prescale(prescale)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, alpha) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float alpha : std::vector<float>({0.3f, 3.0f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .alpha(alpha)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, beta) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float beta : std::vector<float>({0.3f, 3.0f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .beta(beta)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OP_TYPE == "LeakyReLU":
  TEST(${TEST_NAME}, slope) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float slope : std::vector<float>({-0.7f, 0.3f, 1.3f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .slope(slope)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OP_TYPE == "SquareRootShift":
  TEST(${TEST_NAME}, shift) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t shift = 0; shift < 32; shift++) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .shift(shift)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
"""


def generate_test_cases(ukernel, op_type, init_fn, batch_tile, isa):
  """Generates all tests cases for a Vector Unary Operation micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    op_type: Operation type.
    init_fn: C name of the function to initialize microkernel parameters, or
      None (an empty string is treated the same as None).
    batch_tile: Number of batch elements processed per one iteration of the
      inner loop of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  _, test_name = ukernel.split("_", 1)
  _, datatype, _ = ukernel.split("_", 2)
  test_args = [ukernel]
  # Rounding kernels are told which rounding mode to verify via an explicit
  # OpType argument passed to the tester.
  if op_type.startswith("Round"):
    test_args.append("VUnaryMicrokernelTester::OpType::" + op_type)
  # Only append a real init function name; an empty string would render as
  # a dangling ", " in the generated .Test(...) call.
  if init_fn:
    test_args.append(init_fn)
  return xngen.preprocess(BINOP_TEST_TEMPLATE, {
    "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
    "TEST_ARGS": test_args,
    "DATATYPE": datatype,
    "BATCH_TILE": batch_tile,
    "OP_TYPE": op_type,
    "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
  })


def _write_if_changed(path, text):
  """Writes text to path as UTF-8 unless the file already has that content.

  Skipping unchanged writes keeps the file's modification time stable.
  """
  if os.path.exists(path):
    with codecs.open(path, "r", encoding="utf-8") as output_file:
      if output_file.read() == text:
        return
  with codecs.open(path, "w", encoding="utf-8") as output_file:
    output_file.write(text)


def main(args):
  """Generates a C++ gtest source from a YAML micro-kernel specification.

  Args:
    args: command-line arguments (without the program name).

  Raises:
    ValueError: if the spec is not a list of mappings that each carry a
      "name" key, or if a micro-kernel name cannot be parsed.
  """
  options = parser.parse_args(args)

  # Load the spec and close the file before the (longer) generation step.
  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/vunary.h>
#include "vunary-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  test_cases = []
  for ukernel_spec in spec_yaml:
    if not isinstance(ukernel_spec, dict) or "name" not in ukernel_spec:
      raise ValueError(
          "expected a micro-kernel entry with a \"name\" key in the spec, "
          "got %r" % (ukernel_spec,))
    name = ukernel_spec["name"]
    init_fn = ukernel_spec.get("init")
    op_type, batch_tile, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, op_type, init_fn, batch_tile, isa)
    test_cases.append(xnncommon.postprocess_test_case(test_case, arch, isa))

  # Each test case is preceded by a blank-line separator.
  tests = header + "".join("\n\n" + case for case in test_cases)
  _write_if_changed(options.output, tests)


if __name__ == "__main__":
  main(sys.argv[1:])