#!/usr/bin/env python
# Copyright 2022 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Generates GoogleTest unit tests for cs16 BFly4 micro-kernels.

Reads a YAML specification (a list of entries, each with a "name" and an
optional "arch" override) and writes a C++ test source file containing one
group of test cases per listed micro-kernel. The output file is rewritten only
when its contents would actually change.
"""

import argparse
import codecs
import math
import os
import re
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon


parser = argparse.ArgumentParser(description='BFly4 microkernel test generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=list())

# Micro-kernel naming scheme:
#   xnn_cs16_bfly4[_samples<N>]_ukernel__<target>[_x<T>]
# The <target> group is deliberately non-greedy: a greedy "(.+)" would swallow
# the optional "_x<T>" suffix, so the tile group could never match and the
# suffix would leak into the target name.
_UKERNEL_NAME_RE = re.compile(
    r"xnn_cs16_bfly4(_samples(\d+))?_ukernel__(.+?)(_x(\d+))?")


def split_ukernel_name(name):
  """Splits a BFly4 micro-kernel name into its components.

  Args:
    name: C name of the micro-kernel function, e.g.
      "xnn_cs16_bfly4_samples1_ukernel__scalar" or
      "xnn_cs16_bfly4_ukernel__scalar_x2".

  Returns:
    A (samples, samples_tile, arch, isa) tuple. `samples` is 0 when the name
    has no "_samples<N>" part, and `samples_tile` is 1 when the name has no
    "_x<T>" suffix. `arch` and `isa` are whatever
    xnncommon.parse_target_name returns for the target part of the name.

  Raises:
    ValueError: if `name` does not follow the BFly4 naming scheme.
  """
  match = _UKERNEL_NAME_RE.fullmatch(name)
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if match is None:
    raise ValueError("unexpected BFly4 micro-kernel name: " + name)
  samples = int(match.group(2)) if match.group(2) else 0
  samples_tile = int(match.group(5)) if match.group(5) else 1

  arch, isa = xnncommon.parse_target_name(target_name=match.group(3))
  return samples, samples_tile, arch, isa


BFLY4_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, samples_eq_1) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  BFly4MicrokernelTester()
    .samples(1)
    .stride(64)
    .Test(${", ".join(TEST_ARGS)});
}

$if SAMPLES == 0:
  TEST(${TEST_NAME}, samples_eq_4) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    BFly4MicrokernelTester()
      .samples(4)
      .stride(16)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, samples_eq_16) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    BFly4MicrokernelTester()
      .samples(16)
      .stride(4)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, samples_eq_64) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    BFly4MicrokernelTester()
      .samples(64)
      .stride(1)
      .Test(${", ".join(TEST_ARGS)});
  }

"""


def generate_test_cases(ukernel, samples, samples_tile, isa):
  """Generates all test cases for a BFly4 micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    samples: fixed number of samples the micro-kernel is specialized for (e.g.
      1 for a "samples1" kernel), or 0 for a generic kernel. Only generic
      kernels get the additional samples_eq_4/16/64 tests.
    samples_tile: Number of samples processed per one iteration of the inner
      loop of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  _, test_name = ukernel.split("_", 1)
  # "xnn_<datatype>_<ukernel type>_..." -> keep only the datatype.
  _, datatype, _, _ = ukernel.split("_", 3)
  return xngen.preprocess(BFLY4_TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": [ukernel],
      "DATATYPE": datatype,
      "SAMPLES": samples,
      "SAMPLE_TILE": samples_tile,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  })


def main(args):
  """Generates the BFly4 micro-kernel test file from a YAML specification.

  Args:
    args: command-line arguments without the program name; parsed with the
      module-level `parser` (requires --spec and --output).

  Raises:
    ValueError: if the specification is not a YAML list, or a micro-kernel
      name does not follow the BFly4 naming scheme.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/fft.h>
#include "bfly4-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # The header followed by each test case, all separated by a blank line.
  sections = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    samples, samples_tile, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, samples, samples_tile, isa)
    sections.append(xnncommon.postprocess_test_case(test_case, arch, isa))
  tests = "\n\n".join(sections)

  # Only touch the output when its contents change, so that build systems
  # relying on timestamps do not rebuild needlessly.
  txt_changed = True
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      txt_changed = output_file.read() != tests

  if txt_changed:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])