#!/usr/bin/env python
# Copyright 2021 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Vector conversion (VCVT) micro-kernel test generator.

Reads a YAML list of micro-kernel specifications and writes a C++ source file
containing Google Test cases for every listed micro-kernel.
"""

import argparse
import codecs
import math
import os
import re
import sys
import yaml

# Put this script's own directory first on the module search path so the
# sibling helper modules below are found regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import xngen
import xnncommon


parser = argparse.ArgumentParser(
  description='Vector conversion operation microkernel test generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=[])


def split_ukernel_name(name):
  """Decomposes a VCVT micro-kernel name into its components.

  Accepted names look like ``xnn_<in>[_<out>]_vcvt_ukernel__<target>_x<tile>``
  where ``<in>`` and ``<out>`` are one of f16, f32, qs8, qu8. When ``<out>`` is
  omitted, the output datatype is the same as the input datatype.

  Args:
    name: C name of the micro-kernel function.

  Returns:
    A tuple ``(input_datatype, output_datatype, batch_tile, arch, isa)``;
    ``arch`` and ``isa`` are whatever ``xnncommon.parse_target_name`` returns
    for the ``<target>`` portion of the name.

  Raises:
    ValueError: if ``name`` does not follow the expected pattern.
  """
  parts = re.fullmatch(
    r"xnn_(f16|f32|qs8|qu8)(_(f16|f32|qs8|qu8))?_vcvt_ukernel__(.+)_x(\d+)",
    name)
  if not parts:
    raise ValueError("Unexpected microkernel name: " + name)

  src_type, _, dst_type, target_name, tile = parts.groups()
  # The output datatype is optional in the name and defaults to the input one.
  if dst_type is None:
    dst_type = src_type
  tile_size = int(tile)

  arch, isa = xnncommon.parse_target_name(target_name=target_name)
  return src_type, dst_type, tile_size, arch, isa


# NOTE: this template is expanded by xngen.preprocess. Lines starting with
# `$if` / `$elif` are template directives and `${...}` are substitutions; the
# directive bodies are nested by indentation, so keep the whitespace intact.
CVT_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, batch_eq_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  VCvtMicrokernelTester()
    .batch_size(${BATCH_TILE})
    $if OUTPUT_DATATYPE == "QS8":
      .qmin(std::numeric_limits<int8_t>::min())
      .qmax(std::numeric_limits<int8_t>::max())
    $elif OUTPUT_DATATYPE == "QU8":
      .qmin(std::numeric_limits<uint8_t>::min())
      .qmax(std::numeric_limits<uint8_t>::max())
    .Test(${", ".join(TEST_ARGS)});
}

$if BATCH_TILE > 1:
  TEST(${TEST_NAME}, batch_div_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = ${BATCH_TILE*2}; batch_size < ${BATCH_TILE*10}; batch_size += ${BATCH_TILE}) {
      VCvtMicrokernelTester()
        .batch_size(batch_size)
        $if OUTPUT_DATATYPE == "QS8":
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(std::numeric_limits<int8_t>::max())
        $elif OUTPUT_DATATYPE == "QU8":
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(std::numeric_limits<uint8_t>::max())
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, batch_lt_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size < ${BATCH_TILE}; batch_size++) {
      VCvtMicrokernelTester()
        .batch_size(batch_size)
        $if OUTPUT_DATATYPE == "QS8":
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(std::numeric_limits<int8_t>::max())
        $elif OUTPUT_DATATYPE == "QU8":
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(std::numeric_limits<uint8_t>::max())
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch_size = ${BATCH_TILE+1}; batch_size < ${10 if BATCH_TILE == 1 else BATCH_TILE*2}; batch_size++) {
    VCvtMicrokernelTester()
      .batch_size(batch_size)
      $if OUTPUT_DATATYPE == "QS8":
        .qmin(std::numeric_limits<int8_t>::min())
        .qmax(std::numeric_limits<int8_t>::max())
      $elif OUTPUT_DATATYPE == "QU8":
        .qmin(std::numeric_limits<uint8_t>::min())
        .qmax(std::numeric_limits<uint8_t>::max())
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if INPUT_DATATYPE.startswith("Q") or OUTPUT_DATATYPE.startswith("Q"):
  TEST(${TEST_NAME}, scale) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
      VCvtMicrokernelTester()
        .batch_size(batch_size)
        .scale(50)
        $if OUTPUT_DATATYPE == "QS8":
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(std::numeric_limits<int8_t>::max())
        $elif OUTPUT_DATATYPE == "QU8":
          .output_zero_point(100)
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(std::numeric_limits<uint8_t>::max())
        .Test(${", ".join(TEST_ARGS)});
    }
  }

$if INPUT_DATATYPE in ["QS8", "QU8"]:
  TEST(${TEST_NAME}, input_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t input_zero_point = 0; input_zero_point < 5; input_zero_point += 2) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .input_zero_point(input_zero_point)
          $if OUTPUT_DATATYPE == "QS8":
            .qmin(std::numeric_limits<int8_t>::min())
            .qmax(std::numeric_limits<int8_t>::max())
          $elif OUTPUT_DATATYPE == "QU8":
            .qmin(std::numeric_limits<uint8_t>::min())
            .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OUTPUT_DATATYPE in ["QS8", "QU8"]:
  TEST(${TEST_NAME}, output_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t output_zero_point = 0; output_zero_point < 5; output_zero_point += 2) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .output_zero_point(output_zero_point)
          $if OUTPUT_DATATYPE == "QS8":
            .qmin(std::numeric_limits<int8_t>::min())
            .qmax(std::numeric_limits<int8_t>::max())
          $elif OUTPUT_DATATYPE == "QU8":
            .qmin(std::numeric_limits<uint8_t>::min())
            .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  $if INPUT_DATATYPE == "F32":
    TEST(${TEST_NAME}, saturation) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          $if OUTPUT_DATATYPE == "QS8":
            .qmin(std::numeric_limits<int8_t>::min())
            .qmax(std::numeric_limits<int8_t>::max())
          $elif OUTPUT_DATATYPE == "QU8":
            .output_zero_point(128)
            .qmin(std::numeric_limits<uint8_t>::min())
            .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, overflow) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(4294967296.0f)
          $if OUTPUT_DATATYPE == "QS8":
            .qmin(std::numeric_limits<int8_t>::min())
            .qmax(std::numeric_limits<int8_t>::max())
          $elif OUTPUT_DATATYPE == "QU8":
            .qmin(std::numeric_limits<uint8_t>::min())
            .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }

$if INPUT_DATATYPE == "F32" and OUTPUT_DATATYPE == "QS8":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmin = -128; qmin < 127; qmin += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .qmin(qmin)
          .qmax(std::numeric_limits<int8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmax = -127; qmax <= 127; qmax += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(qmax)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if INPUT_DATATYPE == "F32" and OUTPUT_DATATYPE == "QU8":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmin = 0; qmin < 255; qmin += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .output_zero_point(128)
          .qmin(qmin)
          .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmax = 1; qmax <= 255; qmax += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .output_zero_point(128)
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(qmax)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
"""


def generate_test_cases(ukernel, init_fn, input_datatype, output_datatype,
                        batch_tile, isa):
  """Renders all test cases for one Vector Convert Operation micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    init_fn: C name of the function that initializes the micro-kernel
      parameters, or a falsy value when the micro-kernel has none. When given,
      it is passed to the tester's ``Test`` call after ``ukernel``.
    input_datatype: input conversion data type (e.g. "f32"); upper-cased
      before being handed to the template.
    output_datatype: output conversion data type; upper-cased likewise.
    batch_tile: Number of batch elements processed per one iteration of the
      inner loop of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test cases, as produced by ``xngen.preprocess`` from
    ``CVT_TEST_TEMPLATE``.
  """
  # Drop the leading "xnn" prefix; the rest (minus "ukernel_") names the suite.
  _, unprefixed_name = ukernel.split("_", 1)
  test_args = [ukernel, init_fn] if init_fn else [ukernel]

  template_vars = {
    "TEST_NAME": unprefixed_name.upper().replace("UKERNEL_", ""),
    "TEST_ARGS": test_args,
    "BATCH_TILE": batch_tile,
    "INPUT_DATATYPE": input_datatype.upper(),
    "OUTPUT_DATATYPE": output_datatype.upper(),
    "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
  }
  return xngen.preprocess(CVT_TEST_TEMPLATE, template_vars)


def main(args):
  """Generates the C++ test source described by command-line ``args``.

  Each spec entry must provide ``name`` and may provide ``init`` (parameter
  initialization function) and ``arch`` (overrides the architecture parsed
  from the name). The output file is rewritten only if its text would change.

  Raises:
    ValueError: if the spec is not a YAML list, or a micro-kernel name does
      not match the expected pattern.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/vcvt.h>
#include "vcvt-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  sections = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    input_datatype, output_datatype, batch_tile, arch, isa = \
      split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(
      name, ukernel_spec.get("init"), input_datatype, output_datatype,
      batch_tile, isa)
    sections.append(xnncommon.postprocess_test_case(test_case, arch, isa))
  # Header and each micro-kernel's tests are separated by two newlines.
  tests = "\n\n".join(sections)

  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as existing_file:
      up_to_date = existing_file.read() == tests
  else:
    up_to_date = False

  # Leave an identical file untouched (presumably so its modification time
  # does not trigger needless rebuilds).
  if not up_to_date:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])