xref: /aosp_15_r20/external/XNNPACK/tools/generate-gemm-test.py (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import bisect
9import codecs
10import collections
11import os
12import sys
13import yaml
14import zlib
15
16sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
17from primes import next_prime
18import xngen
19import xnncommon
20
# Command-line interface: one YAML spec in, one or more generated C++ sources
# out ("-o" may be repeated; each occurrence is appended to options.output).
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s",
    "--spec",
    required=True,
    metavar="FILE",
    help="Spec (YAML) file")
parser.add_argument(
    "-o",
    "--output",
    required=True,
    action="append",
    metavar="FILE",
    help="Output (C++ source) file(s)")
# The parsed namespace always carries an (initially empty) "defines" list.
parser.set_defaults(defines=[])
32
33
def split_ukernel_name(name):
  """Decodes tile parameters and target information from a micro-kernel name.

  Args:
    name: micro-kernel function name. The part before the first "__" is split
      on "_" and its last field is read as the tile spec (optionally prefixed
      with "upto" and suffixed with "c<KR>" and/or "s<SR>"); the part after the
      first "__" is handed to xnncommon.parse_target_name.

  Returns:
    Tuple (mr, nr, kr, sr, xw, requantization, arch, isa). kr and sr are 1 when
    the tile spec has no "c"/"s" suffix, xw is True when the name contains
    "gemm_xw_", and requantization is None unless the third-from-last field is
    "fp32" or "rndnu".
  """
  prefix, target = name.split("__", 1)
  fields = prefix.split("_")
  is_extended_weights = "gemm_xw_" in prefix

  tile = fields[-1]
  if tile.startswith("upto"):
    tile = tile[len("upto"):]

  # Peel the optional suffixes from the right: "s<SR>" first, then "c<KR>".
  sr = 1
  if "s" in tile:
    tile, sr_text = tile.split("s", 1)
    sr = int(sr_text)

  kr = 1
  if "c" in tile:
    tile, kr_text = tile.split("c", 1)
    kr = int(kr_text)

  # What remains is "<MR>x<NR>".
  mr, nr = (int(dim) for dim in tile.split("x"))
  arch, isa = xnncommon.parse_target_name(target)

  scheme = fields[-3]
  requantization = scheme if scheme in ("fp32", "rndnu") else None

  return mr, nr, kr, sr, is_extended_weights, requantization, arch, isa
59
60
# Template for the C++ unit tests emitted for one micro-kernel. It is expanded
# by xngen.preprocess() in generate_test_cases(), which supplies TEST_NAME,
# TEST_ARGS, UKERNEL_TYPE, DATATYPE, ACTIVATION, MR, NR, KR, SR,
# EXTENDED_WEIGHTS, KBLOCK, ADJKBLOCK, IS_PIPELINED, ISA_CHECK and next_prime.
# NOTE(review): "$if" lines and "${...}" spans look like xngen conditionals and
# expression substitutions -- confirm against xngen before editing; any change
# to this text changes every generated test file.
GEMM_TEST_CODE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cn_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .a_stride(${next_prime(KBLOCK + 1)})
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    for (uint32_t m = 1; m <= ${MR}; m++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(m)
        .n(n)
        .k(${KBLOCK})
        .iterations(1)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_m) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m <= ${MR}; m++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(m)
      .n(${NR})
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}


TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_n) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(n)
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .Test(${", ".join(TEST_ARGS)});
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(${KBLOCK * 2})
        .a_stride(${next_prime(KBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }

  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(${KBLOCK * 2})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(ADJKBLOCK + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${ADJKBLOCK * 10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if UKERNEL_TYPE.startswith("GEMM"):
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_stride(${next_prime(10 if ADJKBLOCK == 1 else ADJKBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE.startswith("GEMM"):
    TEST(${TEST_NAME}, k_div_${KBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(KBLOCK * 10 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_gt_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_gt_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_div_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_div_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, small_kernel_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .ks(3)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

  TEST(${TEST_NAME}, n_gt_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, n_div_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, strided_cm_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .cm_stride(${next_prime(NR + 1)})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, a_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, zero) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t mz = 0; mz < ${MR}; mz++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .ks(3)
          .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
          .zero_index(mz)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if ACTIVATION == "MINMAX":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, strided_cm) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cm_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if DATATYPE == "qu8":
  TEST(${TEST_NAME}, no_a_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_b_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

$if TEST_NAME.startswith('GENERATE') and 'UPTO' in TEST_NAME:
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_m_upto_mr) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t max_mr = 1; max_mr <= ${MR}; max_mr++) {
      for (uint32_t m = 1; m <= max_mr; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(max_mr)
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(${NR})
          .k(${KBLOCK})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
"""
873
874
def generate_test_cases(ukernel, mr, nr, kr, sr, xw, k_block, init_fn,
                        requantization, is_pipelined, isa, jit):
  """Generates all test cases for a GEMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    mr: MR parameter of the GEMM micro-kernel.
    nr: NR parameter of the GEMM micro-kernel.
    kr: KR parameter of the GEMM micro-kernel.
    sr: SR parameter of the GEMM micro-kernel.
    xw: boolean indicator for microkernel with extended weights.
    k_block: Number of K values processed per one iteration of the main loop of
      the micro-kernel.
    init_fn: C name of the function to initialize microkernel parameters, or
      None if the micro-kernel takes no parameters.
    requantization: name of the requantization scheme used by the microkernel.
    is_pipelined: Indicates if the micro-kernel is implemented with software
      pipelining. Additional test cases are generated for software pipelined
      micro-kernels to separately test prologue + epilogue of the pipelined loop
      and iteration of the pipelined loop.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.
    jit: if we are generating test code for JIT codegen.

  Returns:
    Code for the test case.
  """
  _, ukernel_name = ukernel.split("_", 1)

  if jit:
    # JIT names have one extra leading field and no activation field, so the
    # activation is taken from init_fn. A missing init_fn, or one without
    # "minmax", means "linear" (previously this crashed on None).
    _, _, datatype, ukernel_type, _ = ukernel.split("_", 4)
    if init_fn and "minmax" in init_fn:
      activation = "minmax"
    else:
      activation = "linear"
  else:
    _, datatype, ukernel_type, activation, _ = ukernel.split("_", 4)
    # A name without an activation field has "ukernel" in that position.
    if activation == "ukernel":
      activation = "linear"

  # Arguments passed to GemmMicrokernelTester::Test() in the generated code.
  test_args = [ukernel]
  if init_fn:
    test_args.append(init_fn)
    if requantization:
      # "qc8" kernels use the "qs8" requantization functions.
      requantization_datatype = {"qc8": "qs8"}.get(datatype, datatype)
      test_args.append("xnn_%s_requantize_%s" % \
        (requantization_datatype, requantization))

  return xngen.preprocess(
      GEMM_TEST_CODE, {
          "TEST_NAME": ukernel_name.upper().replace("UKERNEL_", ""),
          "TEST_ARGS": test_args,
          "UKERNEL_TYPE": ukernel_type.upper(),
          "DATATYPE": datatype,
          "ACTIVATION": activation.upper(),
          "MR": mr,
          "NR": nr,
          "KR": kr,
          "SR": sr,
          "EXTENDED_WEIGHTS": xw,
          "KBLOCK": k_block,
          # Pipelined kernels are given twice the block size as ADJKBLOCK.
          "ADJKBLOCK": 2 * k_block if is_pipelined else k_block,
          "IS_PIPELINED": is_pipelined,
          "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
          "next_prime": next_prime,
      })
941
942
def main(args):
  """Generates the test source file(s) described by a YAML spec.

  Each micro-kernel in the spec is rendered with generate_test_cases() and
  appended to one of the output files, chosen by CRC32 of the kernel name
  modulo the number of outputs. An output file is rewritten only when its
  content would change.

  Args:
    args: command-line arguments, without the program name.

  Raises:
    ValueError: if the spec does not parse to a list.
  """
  options = parser.parse_args(args)
  output_count = len(options.output)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
    if not isinstance(spec_yaml, list):
      raise ValueError("expected a list of micro-kernels in the spec")

    preamble_template = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/allocator.h>
#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>
#include <xnnpack/microparams-init.h>

#include <xnnpack/gemm.h>
#include <xnnpack/igemm.h>
#include <xnnpack/ppmm.h>
#include "gemm-microkernel-tester.h"
"""
    preamble = preamble_template.format(
        specification=options.spec, generator=sys.argv[0])

    # Every output starts out as the preamble, including outputs that end up
    # with no micro-kernel assigned to them.
    generated = collections.defaultdict(lambda: preamble)

    for entry in spec_yaml:
      name = entry["name"]
      k_block = int(entry["k-block"])
      init_fn = entry.get("init")
      pipelined = bool(entry.get("pipelined", False))
      assembly = bool(entry.get("assembly", False))
      jit = name.startswith("xnn_generate")
      mr, nr, kr, sr, xw, requantization, arch, isa = split_ukernel_name(name)

      # An explicit "arch" in the spec overrides the one parsed from the name.
      arch = entry.get("arch", arch)

      test_case = generate_test_cases(name, mr, nr, kr, sr, xw, k_block,
                                      init_fn, requantization, pipelined, isa,
                                      jit)

      # Pick the output file from a hash of the micro-kernel name.
      bucket = zlib.crc32(name.encode("utf-8")) % output_count
      destination = options.output[bucket]
      generated[destination] += "\n\n" + xnncommon.postprocess_test_case(
          test_case, arch, isa, assembly, jit)

    for output_name in options.output:
      new_text = generated[output_name]

      # Leave an up-to-date file untouched.
      if os.path.exists(output_name):
        with codecs.open(output_name, "r", encoding="utf-8") as existing_file:
          if existing_file.read() == new_text:
            continue

      with codecs.open(output_name, "w", encoding="utf-8") as output_file:
        output_file.write(new_text)
1013
1014
1015if __name__ == "__main__":
1016  main(sys.argv[1:])
1017