xref: /aosp_15_r20/external/XNNPACK/tools/generate-spmm-test.py (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker#!/usr/bin/env python
2*4bdc9457SAndroid Build Coastguard Worker# Copyright 2019 Google LLC
3*4bdc9457SAndroid Build Coastguard Worker#
4*4bdc9457SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*4bdc9457SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*4bdc9457SAndroid Build Coastguard Worker
7*4bdc9457SAndroid Build Coastguard Workerimport argparse
8*4bdc9457SAndroid Build Coastguard Workerimport bisect
9*4bdc9457SAndroid Build Coastguard Workerimport codecs
10*4bdc9457SAndroid Build Coastguard Workerimport os
11*4bdc9457SAndroid Build Coastguard Workerimport sys
12*4bdc9457SAndroid Build Coastguard Workerimport yaml
13*4bdc9457SAndroid Build Coastguard Worker
14*4bdc9457SAndroid Build Coastguard Workersys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15*4bdc9457SAndroid Build Coastguard Workerfrom primes import next_prime
16*4bdc9457SAndroid Build Coastguard Workerimport xngen
17*4bdc9457SAndroid Build Coastguard Workerimport xnncommon
18*4bdc9457SAndroid Build Coastguard Worker
19*4bdc9457SAndroid Build Coastguard Worker
# Command-line interface of the generator: the spec to read and the C++ file
# to produce are both mandatory options.
parser = argparse.ArgumentParser(description='XNNPACK generator')
parser.add_argument(
    "-s", "--spec",
    required=True,
    metavar="FILE",
    help="Spec (YAML) file",
)
parser.add_argument(
    "-o", "--output",
    required=True,
    metavar="FILE",
    help='Output (C++ source) file',
)
# Every parsed namespace also carries an (initially empty) `defines` list.
parser.set_defaults(defines=[])
26*4bdc9457SAndroid Build Coastguard Worker
27*4bdc9457SAndroid Build Coastguard Worker
def split_ukernel_name(name):
  """Breaks a micro-kernel name into its tile sizes and target description.

  The text before the first "__" is split on "_" and its last component is
  read as "<MR>x<NR>"; the text after the first "__" is handed to
  xnncommon.parse_target_name.

  Args:
    name: micro-kernel function name containing a "__" separator.

  Returns:
    Tuple (mr, nr, arch, isa): mr and nr are ints, arch and isa are the two
    values returned by xnncommon.parse_target_name.

  Raises:
    ValueError: if the name has no "__" or the tile component is not two
      integers joined by "x".
  """
  kernel_part, target_part = name.split("__", 1)
  tile_spec = kernel_part.rpartition("_")[2]
  mr, nr = (int(dim) for dim in tile_spec.split("x"))
  arch, isa = xnncommon.parse_target_name(target_part)
  return mr, nr, arch, isa
35*4bdc9457SAndroid Build Coastguard Worker
36*4bdc9457SAndroid Build Coastguard Worker
# xngen template for the C++ unit tests of a single SpMM micro-kernel.
# generate_test_cases() expands it with TEST_NAME, TEST_ARGS, MR, NR, KBLOCK,
# ADJKBLOCK (2 * KBLOCK for pipelined kernels, KBLOCK otherwise),
# IS_PIPELINED, ISA_CHECK and next_prime.
#
# Both k_gt tests iterate from ADJKBLOCK + 1, so both are named after
# ADJKBLOCK and, when KBLOCK > 1, stop before ADJKBLOCK * 2. Bounding the loop
# by KBLOCK * 2 instead would leave it empty for pipelined kernels with
# KBLOCK > 1, because there ADJKBLOCK + 1 already exceeds KBLOCK * 2.
#
# NOTE(review): n_div is the only test that never calls .sparsity(); confirm
# that relying on the tester's default there is intentional.
TEST_TEMPLATE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  SpMMMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .sparsity(0.0f)
    .Test(${", ".join(TEST_ARGS)});
}

$if NR > 1:
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = 1; n <= ${NR}; n++) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(n)
        .k(${KBLOCK})
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    SpMMMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .sparsity(0.0f)
      .Test(${", ".join(TEST_ARGS)});
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (uint32_t n = 1; n <= ${NR}; n++) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(${KBLOCK * 2})
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          SpMMMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .m(${MR})
            .n(n)
            .k(k)
            .sparsity(0.0f)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if KBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    SpMMMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .sparsity(0.0f)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if NR > 1:
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if KBLOCK == 1 else ADJKBLOCK * 2}; k++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          SpMMMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .m(${MR})
            .n(n)
            .k(k)
            .sparsity(0.0f)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${max(10, NR * 2)}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if NR > 1:
  TEST(${TEST_NAME}, n_div_${NR}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(k)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, m_lt_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${1}; m < ${MR}; m++) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, m_div_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${MR * 2}; m <= ${MR * 3}; m += ${MR}) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, m_gt_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${MR + 1}; m < ${MR * 2}; m++) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, output_stride) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .output_stride(${next_prime(MR * 2 + 1)})
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, qmin) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, qmax) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, half_sparse) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.5f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, zero_weights) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(1.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}
"""
377*4bdc9457SAndroid Build Coastguard Worker
378*4bdc9457SAndroid Build Coastguard Worker
def generate_test_cases(ukernel, init_fn, mr, nr, k_block, is_pipelined, isa):
  """Renders TEST_TEMPLATE into the unit tests for one SpMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function. It must contain at least
             three underscores: the second and third "_"-separated components
             are exposed to the template as DATATYPE and UKERNEL_TYPE.
    init_fn: C name of the function to initialize microkernel parameters. It
             is passed to every generated Test() call after the micro-kernel.
    mr: MR parameter of the micro-kernel (template variable MR).
    nr: NR parameter of the micro-kernel (template variable NR).
    k_block: Number of K values processed per one iteration of the main loop
             of the micro-kernel (template variable KBLOCK).
    is_pipelined: Indicates if the micro-kernel is implemented with software
                  pipelining. When true, ADJKBLOCK is 2 * k_block instead of
                  k_block and the template emits its additional
                  k_eq_<2 * KBLOCK> test cases.
    isa: instruction set required to run the micro-kernel; it is converted to
         the ISA_CHECK template variable by
         xnncommon.generate_isa_check_macro.

  Returns:
    Code for the test cases, as produced by xngen.preprocess.
  """
  # The test suite name is the kernel name minus its first "_"-separated
  # component, upper-cased and with "UKERNEL_" removed.
  _prefix, suite_suffix = ukernel.split("_", 1)
  _prefix, datatype, ukernel_type, _rest = ukernel.split("_", 3)

  adjusted_k_block = k_block
  if is_pipelined:
    adjusted_k_block = 2 * k_block

  template_vars = {
      "TEST_NAME": suite_suffix.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": [ukernel, init_fn],
      "UKERNEL_TYPE": ukernel_type.upper(),
      "DATATYPE": datatype,
      "MR": mr,
      "NR": nr,
      "KBLOCK": k_block,
      "ADJKBLOCK": adjusted_k_block,
      "IS_PIPELINED": is_pipelined,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  }
  return xngen.preprocess(TEST_TEMPLATE, template_vars)
415*4bdc9457SAndroid Build Coastguard Worker
416*4bdc9457SAndroid Build Coastguard Worker
def main(args):
  """Generates the SpMM test source described by a YAML spec.

  Parses `args`, loads the spec, renders the test cases of every micro-kernel
  it lists and stores the result at the requested output path. The output
  file is left untouched when it already contains exactly the generated text.

  Args:
    args: command-line arguments, without the program name.

  Raises:
    ValueError: if the top-level value of the spec is not a list.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
    if not isinstance(spec_yaml, list):
      raise ValueError("expected a list of micro-kernels in the spec")

    header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/spmm.h>
#include "spmm-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

    # File header followed by one "\n\n"-prefixed section per micro-kernel.
    sections = [header]
    for entry in spec_yaml:
      kernel_name = entry["name"]
      init_fn = entry["init"]
      k_block = int(entry["k-block"])
      pipelined = bool(entry.get("pipelined", False))
      mr, nr, arch, isa = split_ukernel_name(kernel_name)

      # specification can override architecture
      arch = entry.get("arch", arch)

      rendered = generate_test_cases(kernel_name, init_fn, mr, nr, k_block,
                                     pipelined, isa)
      sections.append("\n\n")
      sections.append(xnncommon.postprocess_test_case(rendered, arch, isa))
    tests = "".join(sections)

    # Skip the write when the existing output is already up to date.
    if os.path.exists(options.output):
      with codecs.open(options.output, "r", encoding="utf-8") as existing:
        if existing.read() == tests:
          return

    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)
467*4bdc9457SAndroid Build Coastguard Worker
468*4bdc9457SAndroid Build Coastguard Worker
# Script entry point: when the file is executed directly, hand the
# command-line arguments (without the program name) to main().
if __name__ == "__main__":
  main(sys.argv[1:])
471