#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import bisect
import codecs
import collections
import os
import sys
import yaml
import zlib

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon

parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s", "--spec", metavar="FILE", required=True, help="Spec (YAML) file")
parser.add_argument(
    "-o",
    "--output",
    action="append",
    metavar="FILE",
    required=True,
    help="Output (C++ source) file(s)")
parser.set_defaults(defines=list())


def split_ukernel_name(name):
  """Parses tile parameters and target information out of a micro-kernel name.

  The part of the name before the first "__" is split on "_". Its last
  component is the tile spec: "MRxNR", optionally prefixed with "upto" and
  optionally followed by "cKR" and then "sSR". The part after "__" is decoded
  by xnncommon.parse_target_name.

  Args:
    name: C name of the micro-kernel.

  Returns:
    Tuple (mr, nr, kr, sr, xw, requantization, arch, isa):
      kr, sr: 1 when absent from the tile spec.
      xw: True when the name contains "gemm_xw_".
      requantization: "fp32" or "rndnu" when that is the third-from-last
        "_"-separated component of the common part, otherwise None.
  """
  common_name, target_name = name.split("__", 1)
  common_parts = common_name.split("_")
  xw = "gemm_xw_" in common_name
  param_spec = common_parts[-1]
  if param_spec.startswith('upto'):
    param_spec = param_spec[len('upto'):]
  # "s" is stripped before "c" because it comes last in the tile spec.
  if "s" in param_spec:
    param_spec, sr = param_spec.split("s", 1)
    sr = int(sr)
  else:
    sr = 1
  if "c" in param_spec:
    param_spec, kr = param_spec.split("c", 1)
    kr = int(kr)
  else:
    kr = 1
  mr, nr = map(int, param_spec.split("x"))
  arch, isa = xnncommon.parse_target_name(target_name)

  requantization = common_parts[-3]
  if requantization not in ["fp32", "rndnu"]:
    requantization = None

  return mr, nr, kr, sr, xw, requantization, arch, isa


GEMM_TEST_CODE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cn_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .a_stride(${next_prime(KBLOCK + 1)})
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    for (uint32_t m = 1; m <= ${MR}; m++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(m)
        .n(n)
        .k(${KBLOCK})
        .iterations(1)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_m) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m <= ${MR}; m++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(m)
      .n(${NR})
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}


TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_n) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(n)
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .Test(${", ".join(TEST_ARGS)});
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(${KBLOCK * 2})
        .a_stride(${next_prime(KBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }

  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(${KBLOCK * 2})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(ADJKBLOCK + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${ADJKBLOCK * 10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if UKERNEL_TYPE.startswith("GEMM"):
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_stride(${next_prime(10 if ADJKBLOCK == 1 else ADJKBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE.startswith("GEMM"):
    TEST(${TEST_NAME}, k_div_${KBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(KBLOCK * 10 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_gt_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_gt_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_div_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_div_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, small_kernel_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .ks(3)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

  TEST(${TEST_NAME}, n_gt_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, n_div_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, strided_cm_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .cm_stride(${next_prime(NR + 1)})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, a_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, zero) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t mz = 0; mz < ${MR}; mz++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .ks(3)
          .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
          .zero_index(mz)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if ACTIVATION == "MINMAX":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, strided_cm) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cm_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if DATATYPE == "qu8":
  TEST(${TEST_NAME}, no_a_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_b_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

$if TEST_NAME.startswith('GENERATE') and 'UPTO' in TEST_NAME:
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_m_upto_mr) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t max_mr = 1; max_mr <= ${MR}; max_mr++) {
      for (uint32_t m = 1; m <= max_mr; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(max_mr)
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(${NR})
          .k(${KBLOCK})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
"""


def generate_test_cases(ukernel, mr, nr, kr, sr, xw, k_block, init_fn,
                        requantization, is_pipelined, isa, jit):
  """Generates all tests cases for a GEMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    mr: MR parameter of the GEMM micro-kernel.
    nr: NR parameter of the GEMM micro-kernel.
    kr: KR parameter of the GEMM micro-kernel.
    sr: SR parameter of the GEMM micro-kernel.
    xw: boolean indicator for microkernel with extended weights.
    k_block: Number of K values processed per one iteration of the main loop of
      the micro-kernel.
    init_fn: C name of the function to initialize microkernel parameters, or
      None when the spec does not provide one.
    requantization: name of the requantization scheme used by the microkernel.
    is_pipelined: Indicates if the micro-kernel is implemented with software
      pipelining. Additional test cases are generated for software pipelined
      micro-kernels to separately test prologue + epilogue of the pipelined loop
      and iteration of the pipelined loop.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.
    jit: if we are generating test code for JIT codegen.

  Returns:
    Code for the test case.
  """
  _, ukernel_name = ukernel.split("_", 1)

  if jit:
    # JIT generator names ("xnn_generate_<datatype>_<type>_...") carry no
    # activation component, so infer it from the params-init function.
    # Without a minmax initializer (or with no initializer at all) treat the
    # kernel as linear; previously this path crashed on a None init_fn or left
    # activation as None, which then failed in activation.upper().
    _, _, datatype, ukernel_type, _ = ukernel.split("_", 4)
    activation = "minmax" if init_fn and "minmax" in init_fn else "linear"
  else:
    _, datatype, ukernel_type, activation, _ = ukernel.split("_", 4)
    # Kernels without an activation suffix have "ukernel" in that position.
    if activation == "ukernel":
      activation = "linear"

  test_args = [ukernel]
  if init_fn:
    test_args.append(init_fn)
  if requantization:
    # qc8 kernels reuse the qs8 requantization reference functions.
    requantization_datatype = {"qc8": "qs8"}.get(datatype, datatype)
    test_args.append(
        "xnn_%s_requantize_%s" % (requantization_datatype, requantization))

  return xngen.preprocess(
      GEMM_TEST_CODE, {
          "TEST_NAME": ukernel_name.upper().replace("UKERNEL_", ""),
          "TEST_ARGS": test_args,
          "UKERNEL_TYPE": ukernel_type.upper(),
          "DATATYPE": datatype,
          "ACTIVATION": activation.upper(),
          "MR": mr,
          "NR": nr,
          "KR": kr,
          "SR": sr,
          "EXTENDED_WEIGHTS": xw,
          "KBLOCK": k_block,
          "ADJKBLOCK": 2 * k_block if is_pipelined else k_block,
          "IS_PIPELINED": is_pipelined,
          "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
          "next_prime": next_prime,
      })


def main(args):
  """Generates GEMM micro-kernel unit tests from a YAML spec.

  Each micro-kernel listed in the spec is assigned to one of the output files
  by hashing its name (CRC32 modulo the number of outputs). Every output file
  starts with a common header. An output file is rewritten only when its
  generated content differs from what is already on disk.

  Args:
    args: command-line arguments, excluding the program name.

  Raises:
    ValueError: if the spec's top level is not a list of micro-kernels.
  """
  options = parser.parse_args(args)
  num_output_files = len(options.output)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
    if not isinstance(spec_yaml, list):
      raise ValueError("expected a list of micro-kernels in the spec")

    tests = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/allocator.h>
#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>
#include <xnnpack/microparams-init.h>

#include <xnnpack/gemm.h>
#include <xnnpack/igemm.h>
#include <xnnpack/ppmm.h>
#include "gemm-microkernel-tester.h"
""".format(
    specification=options.spec, generator=sys.argv[0])

    # Every output file starts with the shared header, including files that
    # end up with no micro-kernels assigned to them.
    outputs = collections.defaultdict(lambda: tests)

    for ukernel_spec in spec_yaml:
      name = ukernel_spec["name"]
      k_block = int(ukernel_spec["k-block"])
      init_fn = ukernel_spec.get("init")
      pipelined = bool(ukernel_spec.get("pipelined", False))
      assembly = bool(ukernel_spec.get("assembly", False))
      jit = name.startswith("xnn_generate")
      mr, nr, kr, sr, xw, requantization, arch, isa = split_ukernel_name(name)

      # specification can override architecture
      arch = ukernel_spec.get("arch", arch)

      test_case = generate_test_cases(name, mr, nr, kr, sr, xw, k_block,
                                      init_fn, requantization, pipelined, isa,
                                      jit)

      # Hash the name of each microkernel and figure out which output file to
      # write it to.
      output_index = zlib.crc32(bytes(name, 'utf-8')) % num_output_files
      outputs[options.output[output_index]] += "\n\n" + xnncommon.postprocess_test_case(
          test_case, arch, isa, assembly, jit)

    for output_name in options.output:
      # Skip rewriting unchanged files so their timestamps are preserved.
      txt_changed = True
      if os.path.exists(output_name):
        with codecs.open(output_name, "r", encoding="utf-8") as output_file:
          txt_changed = output_file.read() != outputs[output_name]

      if txt_changed:
        with codecs.open(output_name, "w", encoding="utf-8") as output_file:
          output_file.write(outputs[output_name])


if __name__ == "__main__":
  main(sys.argv[1:])