#!/usr/bin/env python
# Copyright 2022 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import math
import os
import re
import sys
import yaml

# Make xngen and xnncommon importable from the directory containing this script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import xngen
import xnncommon


parser = argparse.ArgumentParser(
  description='Vector Leaky ReLU operation microkernel test generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=list())


def split_ukernel_name(name):
  """Splits a Leaky ReLU micro-kernel name into its components.

  Args:
    name: C name of the micro-kernel, of the form
      "xnn_{qs8|qu8}_vlrelu_ukernel__{target}_x{batch_tile}".

  Returns:
    Tuple (datatype, batch_tile, arch, isa): datatype is "qs8" or "qu8",
    batch_tile is the integer after the trailing "_x", and arch/isa are
    whatever xnncommon.parse_target_name derives from {target}.

  Raises:
    ValueError: if name does not match the expected pattern.
  """
  match = re.fullmatch(r"xnn_(qs8|qu8)_vlrelu_ukernel__(.+)_x(\d+)", name)
  if match is None:
    raise ValueError("Unexpected microkernel name: " + name)

  datatype = match.group(1)
  batch_tile = int(match.group(3))

  arch, isa = xnncommon.parse_target_name(target_name=match.group(2))
  return datatype, batch_tile, arch, isa


# Test template rendered by xngen.preprocess with the variables supplied in
# generate_test_cases (TEST_NAME, TEST_ARGS, BATCH_TILE, DATATYPE, ISA_CHECK).
# `${EXPR}` placeholders are substituted and `$if COND:` directives select the
# lines nested under them; indentation presumably delimits those regions, so
# keep it intact when editing.
LRELU_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, batch_eq_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  VLReLUMicrokernelTester()
    .batch_size(${BATCH_TILE})
    $if DATATYPE == "QU8":
      .input_zero_point(150)
      .output_zero_point(100)
    .Test(${", ".join(TEST_ARGS)});
}

$if BATCH_TILE > 1:
  TEST(${TEST_NAME}, batch_div_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = ${BATCH_TILE*2}; batch_size < ${BATCH_TILE*10}; batch_size += ${BATCH_TILE}) {
      VLReLUMicrokernelTester()
        .batch_size(batch_size)
        $if DATATYPE == "QU8":
          .input_zero_point(150)
          .output_zero_point(100)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, batch_lt_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size < ${BATCH_TILE}; batch_size++) {
      VLReLUMicrokernelTester()
        .batch_size(batch_size)
        $if DATATYPE == "QU8":
          .input_zero_point(150)
          .output_zero_point(100)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch_size = ${BATCH_TILE+1}; batch_size < ${10 if BATCH_TILE == 1 else BATCH_TILE*2}; batch_size++) {
    VLReLUMicrokernelTester()
      .batch_size(batch_size)
      $if DATATYPE == "QU8":
        .input_zero_point(150)
        .output_zero_point(100)
      .Test(${", ".join(TEST_ARGS)});
  }
}

TEST(${TEST_NAME}, positive_scale) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
    for (float positive_scale : std::vector<float>({1.0f / 256.0f, 0.3f, 1.3f, 128.0f})) {
      VLReLUMicrokernelTester()
        .batch_size(batch_size)
        .positive_scale(positive_scale)
        $if DATATYPE == "QU8":
          .input_zero_point(150)
          .output_zero_point(100)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, negative_scale) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
    for (float negative_scale : std::vector<float>({-127.99609375f, -1.3f, -0.3f, -1.0f / 256.0f, 1 / 256.0f, 0.3f, 1.3f, 128.0f})) {
      VLReLUMicrokernelTester()
        .batch_size(batch_size)
        .negative_scale(negative_scale)
        $if DATATYPE == "QU8":
          .input_zero_point(150)
          .output_zero_point(100)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, input_zero_point) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (int16_t input_zero_point = 2; input_zero_point < 10; input_zero_point += 3) {
    for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
      VLReLUMicrokernelTester()
        .batch_size(batch_size)
        .input_zero_point(input_zero_point)
        $if DATATYPE == "QU8":
          .output_zero_point(100)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, output_zero_point) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (int16_t output_zero_point = 2; output_zero_point < 10; output_zero_point += 3) {
    for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
      VLReLUMicrokernelTester()
        .batch_size(batch_size)
        $if DATATYPE == "QU8":
          .input_zero_point(150)
        .output_zero_point(output_zero_point)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}
"""


def generate_test_cases(ukernel, init_fn, datatype, batch_tile, isa):
  """Generates all test cases for a Vector Leaky ReLU micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    init_fn: C name of the function to initialize microkernel parameters, or a
      falsy value if no initialization function is passed to the tester.
    datatype: quantized data type of the micro-kernel ("qs8" or "qu8", any
      case). For QU8 the generated tests set explicit input/output zero points
      (150/100) except where a test sweeps that zero point itself.
    batch_tile: Number of batch elements processed per one iteration of the
      inner loop of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test cases (the rendered LRELU_TEST_TEMPLATE).
  """
  # Drop the first underscore-separated component (e.g. "xnn") of the name.
  _, test_name = ukernel.split("_", 1)
  test_args = [ukernel]
  if init_fn:
    test_args.append(init_fn)
  return xngen.preprocess(LRELU_TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "BATCH_TILE": batch_tile,
      "DATATYPE": datatype.upper(),
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
  })


def _write_if_changed(path, text):
  """Writes `text` as UTF-8 to `path` unless the file already has that content.

  An up-to-date file is left untouched (its modification time is preserved);
  presumably this avoids spurious rebuilds of the generated test source.
  """
  if os.path.exists(path):
    with codecs.open(path, "r", encoding="utf-8") as output_file:
      if output_file.read() == text:
        return
  with codecs.open(path, "w", encoding="utf-8") as output_file:
    output_file.write(text)


def main(args):
  """Generates the C++ unit tests for every micro-kernel listed in a spec.

  Args:
    args: command-line arguments without the program name (see `parser`).

  Raises:
    ValueError: if the spec is not a list, if an entry is not a mapping with a
      "name" key, or if a micro-kernel name has an unexpected format.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/vlrelu.h>
#include "vlrelu-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  test_cases = []
  for ukernel_spec in spec_yaml:
    # Fail with a descriptive message instead of a bare KeyError/TypeError.
    if not isinstance(ukernel_spec, dict) or "name" not in ukernel_spec:
      raise ValueError(
        "expected a mapping with a \"name\" key for each micro-kernel in the "
        "spec, got: {!r}".format(ukernel_spec))
    name = ukernel_spec["name"]
    init_fn = ukernel_spec.get("init")
    datatype, batch_tile, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, init_fn, datatype, batch_tile, isa)
    test_cases.append(xnncommon.postprocess_test_case(test_case, arch, isa))

  # Each test case is prefixed with "\n\n" to separate it from the text before.
  tests = header + "".join("\n\n" + test_case for test_case in test_cases)

  _write_if_changed(options.output, tests)


if __name__ == "__main__":
  main(sys.argv[1:])