# mypy: allow-untyped-defs
"""Code generator for the caffe2 AVX2/FMA embedding-lookup kernels.

Emits a C++ source file (embedding_lookup_avx2.cc, embedding_lookup_idx_avx2.cc,
or one of the fused 8-bit row-wise variants, depending on the flags below) that
contains unrolled and generic EmbeddingLookup / EmbeddingLookupIdx
implementations for float, at::Half, at::BFloat16, and uint8_t inputs.
"""

import argparse
import sys


sizeof = {"float": 4, "at::Half": 2, "at::BFloat16": 2, "uint8_t": 1}


def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
    # Emit a fully unrolled kernel: uf AVX2 accumulators of 8 floats each cover
    # one output row of block_size == 8 * uf elements.
    def compute(regid, InType, use_weights, isa, prefetch):
        code = []

        if InType == "float":
            code.append(
                "        vop%d = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (%d)), vop%d);"  # noqa
                % (regid, regid, regid)
            )
        elif InType == "at::Half":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_cvtph_ps(\n"
                "                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n"  # noqa
                "            vop%d);" % (regid, regid, regid)
            )
        elif InType == "at::BFloat16":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_castsi256_ps(_mm256_slli_epi32(\n"
                "                _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
                "                    reinterpret_cast<const __m128i*>(ip + (%d)))),\n"
                "                16)),\n"  # noqa
                "            vop%d);" % (regid, regid, regid)
            )
        elif InType == "uint8_t":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
                "                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))),\n"  # noqa
                "            _mm256_add_ps(vop%d, vbio));" % (regid, regid, regid)
            )
        else:
            assert False

        if prefetch:
            code.append(
                "        _mm_prefetch(\n"
                "            reinterpret_cast<const char*>(&ip_next_T0[%d]), _MM_HINT_T0);"
                % (regid)
            )
        else:
            code.append(
                "        // skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid)
            )

        return code

    code = []
    code.append("    // unrolling " + str(uf) + " times")

    if use_offsets:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    else:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )

    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")
    for i in range(0, uf):
        j = 8 * i
        code.append("      __m256 vop" + str(j) + " = _mm256_setzero_ps();")

    # inner loop
    if use_offsets:
        code.append(
            "      if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append("""\
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];""")
        code.append(
            "      for ("
            + "int64_t"
            + " start = dataInd; dataInd < end_offset - offsets[0];\n           ++dataInd) {"  # noqa
        )
    else:
        code.append(
            "      if (dataInd + lengths[rangeIndex] > index_size) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            "      for ("
            + IndexType
            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"  # noqa
        )
    code.append("        const " + IndexType + " idx = indices[dataInd];")
    code.append(
        "        if (idx < 0 || idx >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    if InType == "uint8_t":
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
        code.append("        " + OutType + " bio;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
        if fused:
            code.append(
                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
                "            &input[idx * fused_block_size + block_size]);"
            )
            code.append("        bio = wgt * scale_bias[1];")
            code.append("        wgt = wgt * scale_bias[0];")
        else:
            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
            code.append("        wgt = wgt * scale_bias[2 * idx];")
        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")

    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
    code.append(
        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            ? (dataInd + prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            : dataInd;".format(
            IndexType
        )
    )
    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    code.append(
        "        const {}* ip_next_T0 = "
        "&input[idx_pref_T0 * fused_block_size];".format(InType)
    )

    for i in range(0, uf):
        j = 8 * i
        cachelinesize = 64
        byteoffset = sizeof[InType] * j
        prefetch = (byteoffset % cachelinesize) == 0
        code.extend(compute(j, InType, use_weights, isa, prefetch))
    code.append("      }")

    if use_offsets:
        code.append("      if (!normalize_by_lengths || length == 0) {")
    else:
        code.append("      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {")
    for i in range(0, uf):
        j = 8 * i
        code.append("        _mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
    code.append("      } else {")
    # inv of length
    if use_offsets:
        code.append("        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);")
    else:
        code.append(
            "        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);"
        )
    for i in range(0, uf):
        j = 8 * i
        code.append(
            "        _mm256_storeu_ps(&op["
            + str(j)
            + "], _mm256_mul_ps("
            + "vop"
            + str(j)
            + ", vlen_inv));"
        )
    code.append("      }")

    code.append("    }")
    return code


def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
    def compute(InType, use_weights, isa):
        code = []
        if InType == "float":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));"  # noqa
            )
        elif InType == "at::Half":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_cvtph_ps(_mm_loadu_si128(\n"
                "                      reinterpret_cast<const __m128i*>(&ip[j]))),\n"
                "                  _mm256_loadu_ps(&op[j])));"
            )
        elif InType == "at::BFloat16":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_castsi256_ps(_mm256_slli_epi32(\n"
                "                      _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
                "                          reinterpret_cast<const __m128i*>(&ip[j]))),\n"
                "                      16)),\n"
                "                  _mm256_loadu_ps(&op[j])));"
            )
        elif InType == "uint8_t":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n"  # noqa
                "                      reinterpret_cast<const __m128i*>(&ip[j])))),\n"
                "                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
            )
        else:
            assert False

        code.append(
            "          _mm_prefetch(\n"
            "              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);"
        )

        return code

    code = []
    if InType == "at::Half":
        code.append("    alignas(64) at::Half vtmp1[8] = {0};")
    if InType == "at::BFloat16":
        code.append("    alignas(64) at::BFloat16 vtmp1[8] = {0};")

    if use_offsets:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    else:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )

    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")

    # initialize to 0
    code.append("      int64_t j = 0;")
    code.append("      for (; j + 8 <= block_size; j += 8) {")
    code.append("        _mm256_storeu_ps(op + j, _mm256_setzero_ps());")
    code.append("      }")
    code.append("      for (; j < block_size; j++) {")
    code.append("        op[j] = 0.0f;")
    code.append("      }")

    # inner loop
    if use_offsets:
        code.append(
            "      if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append("""\
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];""")
        code.append(
            "      for ("
            + "int64_t"
            + " start = dataInd; dataInd < end_offset - offsets[0];\n           ++dataInd) {"  # noqa
        )
    else:
        code.append(
            "      if (dataInd + lengths[rangeIndex] > index_size) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            "      for ("
            + IndexType
            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"  # noqa
        )
    code.append("        const " + IndexType + " idx = indices[dataInd];")
    code.append(
        "        if (idx < 0 || idx >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    if InType == "uint8_t":
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
        code.append("        " + OutType + " bio;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
        if fused:
            code.append(
                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
                "            &input[idx * fused_block_size + block_size]);"
            )
            code.append("        bio = wgt * scale_bias[1];")
            code.append("        wgt = wgt * scale_bias[0];")
        else:
            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
            code.append("        wgt = wgt * scale_bias[2 * idx];")
        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")

    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
    code.append(
        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            ? (dataInd + prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            : dataInd;".format(
            IndexType
        )
    )
    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )
    code.append(
        "        const {}* ip_next_T0 = "
        "&input[idx_pref_T0 * fused_block_size];".format(InType)
    )

    # compute and store main loop
    code.append("        j = 0;")
    code.append("        for (; j + 8 <= block_size; j += 8) {")
    code.extend(compute(InType, use_weights, isa))
    code.append("        }")
    # leftover
    code.append("        for (; j < block_size; j++) {")
    if InType == "float":
        code.append("          op[j] = std::fma(wgt, ip[j], op[j]);")
    elif InType == "at::Half":
        code.append("          vtmp1[0] = ip[j];")
        code.append(
            "          __m256 vtmp2 =\n"
            "              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));"
        )
        code.append("          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
    elif InType == "at::BFloat16":
        code.append("          vtmp1[0] = ip[j];")
        code.append(
            "          __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(\n"
            "              _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),\n"
            "              16));"
        )
        code.append("          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
    elif InType == "uint8_t":
        code.append("          op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
    else:
        assert False

    code.append("        }")

    code.append("      }")

    if use_offsets:
        code.append("      if (normalize_by_lengths && length) {")
        code.append("        float len_inv = 1.0f / length;")
    else:
        code.append("      if (normalize_by_lengths && lengths[rangeIndex]) {")
        code.append("        float len_inv = 1.0f / lengths[rangeIndex];")
    code.append("        __m256 vlen_inv = _mm256_set1_ps(len_inv);")
    code.append("        j = 0;")
    code.append("        for (; j + 8 <= block_size; j += 8) {")
    code.append(
        "          _mm256_storeu_ps(\n"
        "              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));"
    )
    code.append("        }")
    code.append("        for (; j < block_size; j++) {")
    code.append("          op[j] = len_inv * op[j];")
    code.append("        }")

    code.append("      }")

    code.append("    }")
    return code


# start main code
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--filename", help="file name")
parser.add_argument("--fused", action="store_true")
parser.add_argument("--use-offsets", action="store_true")
opts = parser.parse_args()
if opts.filename:
    filename = opts.filename
elif opts.fused:
    if opts.use_offsets:
        filename = "embedding_lookup_fused_8bit_rowwise_idx_avx2.cc"
    else:
        filename = "embedding_lookup_fused_8bit_rowwise_avx2.cc"
else:
    if opts.use_offsets:
        filename = "embedding_lookup_idx_avx2.cc"
    else:
        filename = "embedding_lookup_avx2.cc"

options = [
    ["int32_t", "int", "float", "float", "float", "float"],
    ["int64_t", "int64_t", "float", "float", "float", "float"],
    ["int32_t", "int", "half", "at::Half", "float", "float"],
    ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
    ["int32_t", "int", "bfloat16", "at::BFloat16", "float", "float"],
    ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"],
    ["int32_t", "int", "uint8_t", "uint8_t", "float", "float"],
    ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
]

code = []
# includes
code.append("//// --------------------------")
code.append("//// ATTENTION:")
code.append("//// THIS CODE IS AUTOGENERATED")
code.append("//// BY {}".format(sys.argv[0]))
code.append("//// DO NOT MODIFY!!!")
code.append("//// --------------------------\n")

code.append("#include <c10/util/Half.h>")
code.append("#include <c10/util/BFloat16.h>")
code.append("#include <immintrin.h>")

code.append("namespace caffe2 {\n")
for o in options:
    [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o

    prefix = "Fused8BitRowwise" if opts.fused else ""
    code.append("template <bool IS_WEIGHT_POSITIONAL>")
    if opts.use_offsets:
        fn_base = "{}EmbeddingLookupIdx_{}_{}_{}".format(
            prefix, IndexTypeName, InTypeName, OutTypeName
        )
    else:
        fn_base = "{}EmbeddingLookup_{}_{}_{}".format(
            prefix, IndexTypeName, InTypeName, OutTypeName
        )
    suffix = "__avx2_fma"
    fn = "static bool " + fn_base + suffix
    code.append(fn + "(")

    args = []
    args.append("    const int64_t block_size,")
    args.append("    const int64_t output_size,")
    args.append("    const int64_t index_size,")
    args.append("    const int64_t data_size,")
    args.append("    const " + InType + "* input,")
    args.append("    const " + IndexType + "* indices,")
    if opts.use_offsets:
        args.append("    const " + IndexType + "* offsets,")
    else:
        args.append("    const int* lengths,")
    args.append("    const float* weights,")
    if not opts.fused:
        args.append("    const float* scale_bias,")
    args.append("    bool normalize_by_lengths,")
    args.append("    " + OutType + "* out) {")
    code += args

    code.append("  const " + IndexType + " prefdist_T0 = 16;")
    code.append(
        "  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)"
    )
    # block_size is the number of elements and fused_block_size is the size of
    # an entire row, including scale and bias.
    offset = (8 // sizeof[InType]) if opts.fused else 0
    code.append(
        "  const {} fused_block_size = block_size + {};".format(IndexType, offset)
    )
    if opts.use_offsets:
        code.append("  int64_t dataInd = 0;")
    else:
        code.append("  " + IndexType + " dataInd = 0;")

    # code.append("printf(\"calling " + fn + "\\n\");");

    code.append("  if (block_size == 128) {")
    code += unroll(
        16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
    )
    code.append("  } else if (block_size == 64) {")
    code += unroll(
        8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
    )
    code.append("  } else if (block_size == 32) {")
    code += unroll(
        4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
    )
    code.append("  } else if (block_size == 16) {")
    code += unroll(
        2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
    )
    code.append("  } else {")
    code.append("    // generic code")
    code.append(
        "    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)"
    )
    code += generic(
        IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
    )
    code.append("  }")
    code.append("  return dataInd == index_size;")

    code.append("}")

    for is_weight_positional in ["false", "true"]:
        code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(")
        code += args
        # Wrap the template argument when the return line would exceed the
        # 80-character lint limit.
        extra_space = "\n      "
        ret_string = "  return " + fn_base + suffix + "<" + is_weight_positional + ">("
        if len(ret_string) <= 80:
            code.append(ret_string)
        else:
            code.append(
                "  return "
                + fn_base
                + suffix
                + "<"
                + extra_space
                + is_weight_positional
                + ">("
            )
        code.append("      block_size,")
        code.append("      output_size,")
        code.append("      index_size,")
        code.append("      data_size,")
        code.append("      input,")
        code.append("      indices,")
        if opts.use_offsets:
            code.append("      offsets,")
        else:
            code.append("      lengths,")
        code.append("      weights,")
        if not opts.fused:
            code.append("      scale_bias,")
        code.append("      normalize_by_lengths,")
        code.append("      out);")
        code.append("}")

    code.append("")

code.append("} // namespace caffe2")

with open(filename, "w") as fout:
    for c in code:
        # print(c, file = fout)
        fout.write(c + "\n")


print("Created " + filename)
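
# ---------------------------------------------------------------------------
# Example invocations (a sketch; the script filename used below is a
# placeholder -- substitute whatever this file is actually named in your
# checkout):
#
#   python hp_emblookup_codegen.py
#       writes embedding_lookup_avx2.cc
#   python hp_emblookup_codegen.py --use-offsets
#       writes embedding_lookup_idx_avx2.cc
#   python hp_emblookup_codegen.py --fused
#       writes embedding_lookup_fused_8bit_rowwise_avx2.cc
#   python hp_emblookup_codegen.py --fused --use-offsets
#       writes embedding_lookup_fused_8bit_rowwise_idx_avx2.cc
#   python hp_emblookup_codegen.py -f custom_name.cc
#       writes custom_name.cc
# ---------------------------------------------------------------------------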