xref: /aosp_15_r20/external/XNNPACK/tools/generate-dwconv-test.py (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker#!/usr/bin/env python
2*4bdc9457SAndroid Build Coastguard Worker# Copyright 2019 Google LLC
3*4bdc9457SAndroid Build Coastguard Worker#
4*4bdc9457SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*4bdc9457SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*4bdc9457SAndroid Build Coastguard Worker
7*4bdc9457SAndroid Build Coastguard Workerimport argparse
8*4bdc9457SAndroid Build Coastguard Workerimport bisect
9*4bdc9457SAndroid Build Coastguard Workerimport codecs
10*4bdc9457SAndroid Build Coastguard Workerimport math
11*4bdc9457SAndroid Build Coastguard Workerimport os
12*4bdc9457SAndroid Build Coastguard Workerimport sys
13*4bdc9457SAndroid Build Coastguard Workerimport yaml
14*4bdc9457SAndroid Build Coastguard Worker
15*4bdc9457SAndroid Build Coastguard Workersys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16*4bdc9457SAndroid Build Coastguard Workerfrom primes import next_prime
17*4bdc9457SAndroid Build Coastguard Workerimport xngen
18*4bdc9457SAndroid Build Coastguard Workerimport xnncommon
19*4bdc9457SAndroid Build Coastguard Worker
20*4bdc9457SAndroid Build Coastguard Worker
# Command-line interface of the generator: both the YAML spec to read and the
# C++ source file to write are mandatory. `defines` defaults to an empty list.
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s", "--spec", metavar="FILE", required=True,
    help="Spec (YAML) file")
parser.add_argument(
    "-o", "--output", metavar="FILE", required=True,
    help="Output (C++ source) file")
parser.set_defaults(defines=[])
27*4bdc9457SAndroid Build Coastguard Worker
def split_ukernel_name(name):
  """Parses a DWCONV micro-kernel name into tile parameters and target.

  The name is split at the first "__" into a common part and a target part.
  The last "_"-separated token of the common part must look like
  "up<CR>x<KR>" or "up<TILE>p<CR>x<KR>".

  Args:
    name: C name of the micro-kernel function.

  Returns:
    Tuple (primary_tile, cr, kr, requantization, arch, isa). primary_tile is 0
    when the parameter token has no "<TILE>p" prefix. requantization is "fp32"
    or "rndnu" when that is the third-from-last token of the common part, and
    None otherwise. arch and isa are whatever xnncommon.parse_target_name
    returns for the target part.

  Raises:
    ValueError: if the name has no "__" separator, if the parameter token does
      not start with "up", or if the tile sizes cannot be parsed as integers.
  """
  common_name, separator, target_name = name.partition("__")
  if not separator:
    raise ValueError(
      "micro-kernel name %r has no '__' target separator" % name)

  common_parts = common_name.split("_")
  param_spec = common_parts[-1]
  # Explicit check instead of `assert`: asserts are stripped under `python -O`,
  # which would let malformed names slip through to the int() parsing below.
  if not param_spec.startswith("up"):
    raise ValueError(
      "parameter spec %r of micro-kernel %r does not start with 'up'" %
      (param_spec, name))

  tile_spec = param_spec[2:]
  tile_part, has_primary_tile, kernel_part = tile_spec.partition("p")
  if has_primary_tile:
    primary_tile = int(tile_part)
  else:
    # No "<TILE>p" prefix: the whole remainder is "<CR>x<KR>".
    primary_tile, kernel_part = 0, tile_spec
  cr, kr = map(int, kernel_part.split("x"))

  arch, isa = xnncommon.parse_target_name(target_name)

  # Names too short to carry a requantization token simply have none.
  requantization = common_parts[-3] if len(common_parts) >= 3 else None
  if requantization not in ["fp32", "rndnu"]:
    requantization = None

  return primary_tile, cr, kr, requantization, arch, isa
48*4bdc9457SAndroid Build Coastguard Worker
49*4bdc9457SAndroid Build Coastguard Worker
# xngen template for the unit tests of one DWCONV micro-kernel. It is expanded
# by generate_test_cases(), which supplies TEST_NAME, TEST_ARGS, DATATYPE,
# ACTIVATION, CR, KR, CBLOCK, ADJCBLOCK, IS_PIPELINED, ISA_CHECK and
# next_prime. Lines starting with `$if` are template directives; `${...}`
# expressions are substituted with their evaluated value.
#
# The multipixel_with_output_stride test passes the loop variable to
# .channels(): its output stride, next_prime(CR * 5 + 1), is larger than the
# largest channel count of the loop (CBLOCK * 5) whenever CBLOCK == CR, which
# is how main() invokes the generator.
DWCONV_TEST_CODE = """\
TEST(${TEST_NAME}, c_eq_${CBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  DWConvMicrokernelTester()
    .cr(${CR})
    .kr(${KR})
    .channels(${CBLOCK})
    .Test(${", ".join(TEST_ARGS)});
}

$if IS_PIPELINED:
  TEST(${TEST_NAME}, c_eq_${CBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    DWConvMicrokernelTester()
      .cr(${CR})
      .kr(${KR})
      .channels(${CBLOCK * 2})
      .Test(${", ".join(TEST_ARGS)});
  }

$if CBLOCK > 1:
  TEST(${TEST_NAME}, c_div_${CBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t channels = ${ADJCBLOCK + CBLOCK}; channels < ${CR * 16}; channels += ${CR * 3}) {
      DWConvMicrokernelTester()
        .cr(${CR})
        .kr(${KR})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if ACTIVATION == "MINMAX":
    TEST(${TEST_NAME}, c_div_${CBLOCK}_with_qmin) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (uint32_t channels = ${ADJCBLOCK + CBLOCK}; channels < ${CR * 16}; channels += ${CR * 3}) {
        DWConvMicrokernelTester()
          .cr(${CR})
          .kr(${KR})
          .channels(channels)
          .qmin(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, c_div_${CBLOCK}_with_qmax) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (uint32_t channels = ${ADJCBLOCK + CBLOCK}; channels < ${CR * 16}; channels += ${CR * 3}) {
        DWConvMicrokernelTester()
          .cr(${CR})
          .kr(${KR})
          .channels(channels)
          .qmax(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, c_lt_${ADJCBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t channels = 1; channels < ${ADJCBLOCK}; channels++) {
      DWConvMicrokernelTester()
        .cr(${CR})
        .kr(${KR})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, c_gt_${ADJCBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t channels = ${ADJCBLOCK + 1}; channels < ${10 if CBLOCK == 1 else ADJCBLOCK + CBLOCK}; channels++) {
    DWConvMicrokernelTester()
      .cr(${CR})
      .kr(${KR})
      .channels(channels)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if ACTIVATION == "MINMAX":
  TEST(${TEST_NAME}, c_gt_${ADJCBLOCK}_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t channels = ${ADJCBLOCK + 1}; channels < ${10 if CBLOCK == 1 else ADJCBLOCK + CBLOCK}; channels++) {
      DWConvMicrokernelTester()
        .cr(${CR})
        .kr(${KR})
        .channels(channels)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, c_gt_${ADJCBLOCK}_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t channels = ${ADJCBLOCK + 1}; channels < ${10 if CBLOCK == 1 else ADJCBLOCK + CBLOCK}; channels++) {
      DWConvMicrokernelTester()
        .cr(${CR})
        .kr(${KR})
        .channels(channels)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, multipixel) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t channels = 1; channels <= ${CBLOCK * 5}; channels += ${max(1, CBLOCK - 1)}) {
    DWConvMicrokernelTester()
      .cr(${CR})
      .kr(${KR})
      .channels(channels)
      .width(3)
      .Test(${", ".join(TEST_ARGS)});
  }
}

TEST(${TEST_NAME}, multipixel_with_step) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t channels = 1; channels <= ${CBLOCK * 5}; channels += ${max(1, CBLOCK - 1)}) {
    for (size_t step = 2; step <= ${KR}; step++) {
      DWConvMicrokernelTester()
        .cr(${CR})
        .kr(${KR})
        .channels(channels)
        .width(3)
        .step(step)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, multipixel_with_output_stride) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t channels = 1; channels <= ${CBLOCK * 5}; channels += ${max(1, CBLOCK - 1)}) {
    DWConvMicrokernelTester()
      .cr(${CR})
      .kr(${KR})
      .channels(channels)
      .width(5)
      .output_stride(${next_prime(CR * 5 + 1)})
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if ACTIVATION == "MINMAX":
  TEST(${TEST_NAME}, multipixel_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = 1; channels <= ${CBLOCK * 5}; channels += ${max(1, CBLOCK - 1)}) {
      DWConvMicrokernelTester()
        .cr(${CR})
        .kr(${KR})
        .channels(channels)
        .width(3)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, multipixel_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = 1; channels <= ${CBLOCK * 5}; channels += ${max(1, CBLOCK - 1)}) {
      DWConvMicrokernelTester()
        .cr(${CR})
        .kr(${KR})
        .channels(channels)
        .width(3)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

$if DATATYPE == "qu8":
  TEST(${TEST_NAME}, input_zero_point_only) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = 1; channels <= ${CBLOCK * 5}; channels += ${max(1, CBLOCK - 1)}) {
      DWConvMicrokernelTester()
        .cr(${CR})
        .kr(${KR})
        .channels(channels)
        .width(3)
        .input_zero_point(255)
        .kernel_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, kernel_zero_point_only) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = 1; channels <= ${CBLOCK * 5}; channels += ${max(1, CBLOCK - 1)}) {
      DWConvMicrokernelTester()
        .cr(${CR})
        .kr(${KR})
        .channels(channels)
        .width(3)
        .input_zero_point(0)
        .kernel_zero_point(255)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, input_offset) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t channels = ${ADJCBLOCK + CBLOCK}; channels < ${CR * 16}; channels += ${CR * 3}) {
    DWConvMicrokernelTester()
      .cr(${CR})
      .kr(${KR})
      .channels(channels)
      .input_offset(${next_prime(CR + 1) * 16})
      .Test(${", ".join(TEST_ARGS)});
  }
}

TEST(${TEST_NAME}, zero) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t mz = 0; mz < ${KR}; mz++) {
    for (uint32_t channels = ${ADJCBLOCK + CBLOCK}; channels < ${CR * 16}; channels += ${CR * 3}) {
      DWConvMicrokernelTester()
        .cr(${CR})
        .kr(${KR})
        .channels(channels)
        .input_offset(${next_prime(CR + 1) * 16})
        .zero_index(mz)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}
"""
295*4bdc9457SAndroid Build Coastguard Worker
def generate_test_cases(ukernel, primary_tile, cr, kr, c_block,
                        init_fn, requantization, is_pipelined, isa):
  """Expands DWCONV_TEST_CODE into the test cases for one DWCONV micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function. Its "_"-separated fields
             provide the datatype, micro-kernel type and activation; the part
             after the first "_" becomes the test name.
    primary_tile: Primary tile parsed from the micro-kernel name; passed to
                  the template as PRIMARY_TILE.
    cr: CR parameter of the DWCONV micro-kernel.
    kr: KR parameter of the DWCONV micro-kernel.
    c_block: Channel block size; passed to the template as CBLOCK and used
             there to derive the channel counts of the generated tests.
    init_fn: C name of the function to initialize micro-kernel parameters.
             When empty or None, only the micro-kernel itself is passed to
             the tester.
    requantization: Name of the requantization scheme used by the
                    micro-kernel, or None. Only used when init_fn is set.
    is_pipelined: Indicates if the micro-kernel is implemented with software
                  pipelining. When true, ADJCBLOCK is twice c_block and the
                  template emits an extra test for two channel blocks.
    isa: Instruction set required to run the micro-kernel; converted to the
         ISA check macro placed at the start of every generated test.

  Returns:
    Code for the test cases.
  """
  _, test_name = ukernel.split("_", 1)
  _, datatype, ukernel_type, activation, _ = ukernel.split("_", 4)
  # Names without an activation token have "ukernel" in that position.
  activation = "linear" if activation == "ukernel" else activation

  # Arguments of the tester's Test() call: the micro-kernel, then optionally
  # the params initializer, then optionally the reference requantization
  # function.
  kernel_args = [ukernel]
  if init_fn:
    kernel_args += [init_fn]
    if requantization:
      # Look up the reference requantization function under "qs8" for "qc8".
      requantize_type = "qs8" if datatype == "qc8" else datatype
      kernel_args += [
        "xnn_%s_requantize_%s" % (requantize_type, requantization)]

  if is_pipelined:
    adjusted_c_block = 2 * c_block
  else:
    adjusted_c_block = c_block

  template_vars = {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": kernel_args,
      "UKERNEL_TYPE": ukernel_type.upper(),
      "DATATYPE": datatype,
      "ACTIVATION": activation.upper(),
      "PRIMARY_TILE": primary_tile,
      "CR": cr,
      "KR": kr,
      "CBLOCK": c_block,
      "ADJCBLOCK": adjusted_c_block,
      "IS_PIPELINED": is_pipelined,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
      "sqrt": math.sqrt,
  }
  return xngen.preprocess(DWCONV_TEST_CODE, template_vars)
345*4bdc9457SAndroid Build Coastguard Worker
346*4bdc9457SAndroid Build Coastguard Worker
def main(args):
  """Generates the C++ test source described by a YAML spec file.

  Parses `args` with the module-level `parser`, loads the spec (which must be
  a list of micro-kernel entries), renders the test cases of every entry and
  writes the result to the output file. The output file is rewritten only
  when it does not exist yet or its current content differs.

  Args:
    args: Command-line arguments, without the program name.

  Raises:
    ValueError: if the top-level YAML value of the spec is not a list.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
    if not isinstance(spec_yaml, list):
      raise ValueError("expected a list of micro-kernels in the spec")

    header = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/dwconv.h>
#include "dwconv-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

    # The file header followed by one section per micro-kernel.
    sections = [header]
    for ukernel_spec in spec_yaml:
      name = ukernel_spec["name"]
      init_fn = ukernel_spec.get("init")
      pipelined = bool(ukernel_spec.get("pipelined", False))
      assembly = bool(ukernel_spec.get("assembly", False))
      primary_tile, cr, kr, requantization, arch, isa = split_ukernel_name(name)

      # specification can override architecture
      arch = ukernel_spec.get("arch", arch)

      # The channel block passed to the generator is CR itself.
      test_case = generate_test_cases(
        name, primary_tile, cr, kr, cr, init_fn, requantization, pipelined, isa)
      sections.append(
        "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa, assembly))
    tests = "".join(sections)

    # Leave an up-to-date output file untouched.
    previous = None
    if os.path.exists(options.output):
      with codecs.open(options.output, "r", encoding="utf-8") as output_file:
        previous = output_file.read()

    if previous != tests:
      with codecs.open(options.output, "w", encoding="utf-8") as output_file:
        output_file.write(tests)
400*4bdc9457SAndroid Build Coastguard Worker
401*4bdc9457SAndroid Build Coastguard Worker
# Run the generator only when this file is executed as a script, forwarding
# the command-line arguments (without the program name) to main().
if __name__ == "__main__":
  main(sys.argv[1:])
404