#!/usr/bin/env python
# Copyright 2022 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import math
import os
import re
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon


parser = argparse.ArgumentParser(description='Window microkernel test generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=list())


def split_ukernel_name(name):
  """Parses a Window micro-kernel name into its tiling and target parameters.

  Names are expected to match
  ``xnn_s16_window[_shift<N>]_ukernel__<target>_x<BATCH_TILE>``.

  Args:
    name: C name of the micro-kernel function.

  Returns:
    A tuple ``(shift, row_tile, batch_tile, arch, isa)``. ``shift`` is 0 when
    the name has no ``_shift<N>`` component. ``row_tile`` is always 1 because
    the naming scheme does not encode a row tile. ``arch`` and ``isa`` are
    whatever ``xnncommon.parse_target_name`` returns for ``<target>``.

  Raises:
    ValueError: if ``name`` does not follow the naming scheme above.
  """
  match = re.fullmatch(r"xnn_s16_window(_shift(\d+))?_ukernel__(.+)_x(\d+)", name)
  # Raise rather than assert: asserts are stripped under `python -O`, which
  # would turn a malformed spec entry into an obscure AttributeError below.
  if match is None:
    raise ValueError("unexpected Window micro-kernel name: %s" % name)

  shift = int(match.group(2)) if match.group(2) else 0
  row_tile = 1
  batch_tile = int(match.group(4))

  arch, isa = xnncommon.parse_target_name(target_name=match.group(3))
  return shift, row_tile, batch_tile, arch, isa


# xngen template for one micro-kernel's GoogleTest cases. `$if` lines are
# preprocessor directives; their bodies are selected by indentation.
WINDOW_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, batch_eq_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  WindowMicrokernelTester()
    .rows(1)
    .batch(${BATCH_TILE})
    .shift(${SHIFT})
    .Test(${", ".join(TEST_ARGS)});
}

$if BATCH_TILE > 1:
  TEST(${TEST_NAME}, batch_div_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch = ${BATCH_TILE*2}; batch < ${BATCH_TILE*10}; batch += ${BATCH_TILE}) {
      WindowMicrokernelTester()
        .batch(batch)
        .shift(${SHIFT})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, batch_lt_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch = 1; batch < ${BATCH_TILE}; batch++) {
      WindowMicrokernelTester()
        .batch(batch)
        .shift(${SHIFT})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch = ${BATCH_TILE+1}; batch < ${10 if BATCH_TILE == 1 else BATCH_TILE*2}; batch++) {
    WindowMicrokernelTester()
      .batch(batch)
      .shift(${SHIFT})
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if ROW_TILE > 1:
  TEST(${TEST_NAME}, rows_lt_${ROW_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = 1; rows < ${ROW_TILE}; rows++) {
      for (size_t batch = 1; batch <= ${BATCH_TILE*5}; batch += ${max(1, BATCH_TILE-1)}) {
        WindowMicrokernelTester()
          .rows(rows)
          .batch(batch)
          .shift(${SHIFT})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, rows_div_${ROW_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${ROW_TILE*2}; rows <= ${ROW_TILE*4}; rows += ${ROW_TILE}) {
      for (size_t batch = 1; batch <= ${BATCH_TILE*5}; batch += ${max(1, BATCH_TILE-1)}) {
        WindowMicrokernelTester()
          .rows(rows)
          .batch(batch)
          .shift(${SHIFT})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, rows_gt_${ROW_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t rows = ${ROW_TILE+1}; rows < ${10 if ROW_TILE == 1 else ROW_TILE*2}; rows++) {
    for (size_t batch = 1; batch <= ${BATCH_TILE*5}; batch += ${max(1, BATCH_TILE-1)}) {
      WindowMicrokernelTester()
        .rows(rows)
        .batch(batch)
        .shift(${SHIFT})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, inplace) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t rows = 1; rows <= ${ROW_TILE*3}; rows += ${max(1, ROW_TILE-1)}) {
    for (size_t batch = 1; batch <= ${BATCH_TILE*5}; batch += ${max(1, BATCH_TILE-1)}) {
      WindowMicrokernelTester()
        .rows(rows)
        .batch(batch)
        .shift(${SHIFT})
        .inplace(true)
        .iterations(1)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if SHIFT == 0:
  TEST(${TEST_NAME}, shift) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t shift = 0; shift < 32; shift++) {
      WindowMicrokernelTester()
        .rows(${ROW_TILE})
        .batch(${BATCH_TILE})
        .shift(shift)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

"""


def generate_test_cases(ukernel, shift, row_tile, batch_tile, isa):
  """Generates all test cases for a Window micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    shift: Shift by constant value.
    row_tile: Number of rows (pixels) processed per one iteration of the outer
      loop of the micro-kernel.
    batch_tile: Number of batch elements processed per one iteration of the
      inner loop of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  _, test_name = ukernel.split("_", 1)
  # Name layout: xnn_<datatype>_<ukernel type>_<rest>; only datatype is used.
  _, datatype, _, _ = ukernel.split("_", 3)
  return xngen.preprocess(WINDOW_TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": [ukernel],
      "DATATYPE": datatype,
      "SHIFT": shift,
      "ROW_TILE": row_tile,
      "BATCH_TILE": batch_tile,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  })


def main(args):
  """Generates the C++ unit-test file for the micro-kernels listed in a spec.

  The output file is only rewritten when its contents would change.

  Args:
    args: command-line arguments, excluding the program name.

  Raises:
    ValueError: if the spec is not a YAML list, or if a micro-kernel name does
      not follow the expected naming scheme.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  tests = """\
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/window.h>
#include "window-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    shift, row_tile, batch_tile, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, shift, row_tile, batch_tile, isa)
    tests += "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa)

  # Skip rewriting an unchanged output so build systems don't see it as stale.
  txt_changed = True
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      txt_changed = output_file.read() != tests

  if txt_changed:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])