xref: /aosp_15_r20/external/pytorch/caffe2/perfkernels/hp_emblookup_codegen.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3
4import argparse
5import sys
6
7
# Bytes per element for each supported embedding input type; used to decide
# how often a 64-byte cache line boundary is crossed (prefetch cadence).
sizeof = {"float": 4, "at::Half": 2, "at::BFloat16": 2, "uint8_t": 1}


def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
    """Emit C++ lines for an embedding-lookup kernel whose block-dimension
    loop is fully unrolled ``uf`` times (8 floats per __m256 register, so the
    emitted code services block_size == 8 * uf).

    Args:
        uf: unroll factor; one ``vop<8*i>`` accumulator is emitted per step.
        IndexType: C++ type of the indices array (e.g. "int", "int64_t").
        InType: C++ element type of the input rows ("float", "at::Half",
            "at::BFloat16", or "uint8_t").
        OutType: C++ element type of the output (always "float" today).
        use_weights: unused; kept for signature compatibility with callers.
        isa: unused; kept for signature compatibility with callers.
        fused: True when each uint8 row carries its own fused scale/bias
            at the end of the row; False reads a separate scale_bias array.
        use_offsets: True emits the offsets-based ("Idx") variant, False the
            lengths-based variant.

    Returns:
        list[str]: the generated C++ source lines.

    Raises:
        ValueError: if InType is not one of the supported types.
    """

    def compute(regid, InType, prefetch):
        # Emit one FMA accumulation into vop<regid> for the 8 elements at
        # ip + regid, plus (optionally) a prefetch of the next row's data.
        code = []

        if InType == "float":
            code.append(
                "        vop%d = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (%d)), vop%d);"  # noqa
                % (regid, regid, regid)
            )
        elif InType == "at::Half":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_cvtph_ps(\n"
                "                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n"  # noqa
                "            vop%d);" % (regid, regid, regid)
            )
        elif InType == "at::BFloat16":
            # bf16 -> fp32 is a plain 16-bit left shift into the high half.
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_castsi256_ps(_mm256_slli_epi32(\n"
                "                _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
                "                    reinterpret_cast<const __m128i*>(ip + (%d)))),\n"
                "                16)),\n"  # noqa
                "            vop%d);" % (regid, regid, regid)
            )
        elif InType == "uint8_t":
            # Quantized path: widen u8 -> i32 -> f32 and add the bias term.
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
                "                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))),\n"  # noqa
                "            _mm256_add_ps(vop%d, vbio));" % (regid, regid, regid)
            )
        else:
            raise ValueError("unsupported InType: " + InType)

        if prefetch:
            code.append(
                "        _mm_prefetch(\n"
                "            reinterpret_cast<const char*>(&ip_next_T0[%d]), _MM_HINT_T0);"
                % (regid)
            )
        else:
            code.append(
                "        // skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid)
            )

        return code

    code = []
    code.append("    // unrolling " + str(uf) + " times")

    # Outer loop over output rows. (Identical text for both the offsets and
    # lengths variants, so no branching on use_offsets here.)
    code.append(
        "    for ("
        + IndexType
        + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
    )

    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")
    # One zeroed accumulator per 8-element group.
    for i in range(uf):
        j = 8 * i
        code.append("      __m256 vop" + str(j) + " = _mm256_setzero_ps();")

    # Inner loop over the indices belonging to this output row.
    if use_offsets:
        code.append(
            "      if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append("""\
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];""")
        code.append(
            "      for ("
            + "int64_t"
            + " start = dataInd; dataInd < end_offset - offsets[0];\n           ++dataInd) {"  # noqa
        )
    else:
        code.append(
            "      if (dataInd + lengths[rangeIndex] > index_size) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            "      for ("
            + IndexType
            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"  # noqa
        )
    code.append("        const " + IndexType + " idx = indices[dataInd];")
    code.append(
        "        if (idx < 0 || idx >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    if InType == "uint8_t":
        # Quantized rows need a per-row scale (folded into wgt) and bias.
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
        code.append("        " + OutType + " bio;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
        if fused:
            # Scale/bias live at the end of each fused row.
            code.append(
                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
                "            &input[idx * fused_block_size + block_size]);"
            )
            code.append("        bio = wgt * scale_bias[1];")
            code.append("        wgt = wgt * scale_bias[0];")
        else:
            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
            code.append("        wgt = wgt * scale_bias[2 * idx];")
        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")

    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
    # Prefetch setup: prefetch the row prefdist_T0 indices ahead, clamped.
    code.append(
        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            ? (dataInd + prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            : dataInd;".format(
            IndexType
        )
    )
    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    code.append(
        "        const {}* ip_next_T0 = "
        "&input[idx_pref_T0 * fused_block_size];".format(InType)
    )

    cachelinesize = 64
    for i in range(uf):
        j = 8 * i
        byteoffset = sizeof[InType] * j
        # Prefetch only once per 64-byte cache line; narrow types pack
        # several 8-element groups into one line.
        prefetch = (byteoffset % cachelinesize) == 0
        code.extend(compute(j, InType, prefetch))
    code.append("      }")

    # Store accumulators, optionally scaled by 1/length.
    if use_offsets:
        code.append("      if (!normalize_by_lengths || length == 0) {")
    else:
        code.append("      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {")
    for i in range(uf):
        j = 8 * i
        code.append("        _mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
    code.append("      } else {")
    # inv of length
    if use_offsets:
        code.append("        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);")
    else:
        code.append("        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
    for i in range(uf):
        j = 8 * i
        code.append(
            "        _mm256_storeu_ps(&op["
            + str(j)
            + "], _mm256_mul_ps("
            + "vop"
            + str(j)
            + ", vlen_inv));"
        )
    code.append("      }")

    code.append("    }")
    return code
202
203
def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
    """Emit C++ lines for the generic (arbitrary block_size) embedding
    lookup: an 8-wide vectorized main loop plus a scalar remainder loop.

    Args mirror unroll(); ``use_weights`` and ``isa`` are unused and kept
    only for signature compatibility with callers.

    Returns:
        list[str]: the generated C++ source lines.

    Raises:
        ValueError: if InType is not one of the supported types.
    """

    def compute(InType):
        # Emit the 8-wide FMA accumulation into op[j] for one input type,
        # followed by a prefetch of the next row's data.
        code = []
        if InType == "float":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));"  # noqa
            )
        elif InType == "at::Half":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_cvtph_ps(_mm_loadu_si128(\n"
                "                      reinterpret_cast<const __m128i*>(&ip[j]))),\n"
                "                  _mm256_loadu_ps(&op[j])));"
            )
        elif InType == "at::BFloat16":
            # bf16 -> fp32 is a plain 16-bit left shift into the high half.
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_castsi256_ps(_mm256_slli_epi32(\n"
                "                      _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
                "                          reinterpret_cast<const __m128i*>(&ip[j]))),\n"
                "                      16)),\n"
                "                  _mm256_loadu_ps(&op[j])));"
            )
        elif InType == "uint8_t":
            # Quantized path: widen u8 -> i32 -> f32 and add the bias term.
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n"  # noqa
                "                      reinterpret_cast<const __m128i*>(&ip[j])))),\n"
                "                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
            )
        else:
            raise ValueError("unsupported InType: " + InType)

        code.append(
            "          _mm_prefetch(\n"
            "              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);"
        )

        return code

    code = []
    # Scratch buffer for the scalar remainder loop of 16-bit input types:
    # one element is copied in and converted through a full vector register.
    if InType == "at::Half":
        code.append("    alignas(64) at::Half vtmp1[8] = {0};")
    if InType == "at::BFloat16":
        code.append("    alignas(64) at::BFloat16 vtmp1[8] = {0};")

    # Outer loop over output rows. (Identical text for both the offsets and
    # lengths variants, so no branching on use_offsets here.)
    code.append(
        "    for ("
        + IndexType
        + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
    )

    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")

    # Zero-initialize the output row: vector part, then scalar tail.
    code.append("      int64_t j = 0;")
    code.append("      for (; j + 8 <= block_size; j += 8) {")
    code.append("        _mm256_storeu_ps(op + j, _mm256_setzero_ps());")
    code.append("      }")
    code.append("      for (; j < block_size; j++) {")
    code.append("        op[j] = 0.0f;")
    code.append("      }")

    # Inner loop over the indices belonging to this output row.
    if use_offsets:
        code.append(
            "      if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append("""\
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];""")
        code.append(
            "      for ("
            + "int64_t"
            + " start = dataInd; dataInd < end_offset - offsets[0];\n           ++dataInd) {"  # noqa
        )
    else:
        code.append(
            "      if (dataInd + lengths[rangeIndex] > index_size) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            "      for ("
            + IndexType
            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"  # noqa
        )
    code.append("        const " + IndexType + " idx = indices[dataInd];")
    code.append(
        "        if (idx < 0 || idx >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    if InType == "uint8_t":
        # Quantized rows need a per-row scale (folded into wgt) and bias.
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
        code.append("        " + OutType + " bio;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
        if fused:
            # Scale/bias live at the end of each fused row.
            code.append(
                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
                "            &input[idx * fused_block_size + block_size]);"
            )
            code.append("        bio = wgt * scale_bias[1];")
            code.append("        wgt = wgt * scale_bias[0];")
        else:
            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
            code.append("        wgt = wgt * scale_bias[2 * idx];")
        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")

    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
    # Prefetch setup: prefetch the row prefdist_T0 indices ahead, clamped.
    code.append(
        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            ? (dataInd + prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            : dataInd;".format(
            IndexType
        )
    )
    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )
    code.append(
        "        const {}* ip_next_T0 = "
        "&input[idx_pref_T0 * fused_block_size];".format(InType)
    )

    # Compute-and-store main loop (8 elements per iteration).
    code.append("        j = 0;")
    code.append("        for (; j + 8 <= block_size; j += 8) {")
    code.extend(compute(InType))
    code.append("        }")
    # Scalar leftover loop for the final block_size % 8 elements.
    code.append("        for (; j < block_size; j++) {")
    if InType == "float":
        code.append("          op[j] = std::fma(wgt, ip[j], op[j]);")
    elif InType == "at::Half":
        code.append("          vtmp1[0] = ip[j];")
        code.append(
            "          __m256 vtmp2 =\n"
            "              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));"
        )
        code.append("          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
    elif InType == "at::BFloat16":
        code.append("          vtmp1[0] = ip[j];")
        code.append(
            "          __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(\n"
            "              _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),\n"
            "              16));"
        )
        code.append("          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
    elif InType == "uint8_t":
        code.append("          op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
    else:
        raise ValueError("unsupported InType: " + InType)

    code.append("        }")

    code.append("      }")

    # Optional normalization of the finished row by 1/length.
    if use_offsets:
        code.append("      if (normalize_by_lengths && length) {")
        code.append("        float len_inv = 1.0f / length;")
    else:
        code.append("      if (normalize_by_lengths && lengths[rangeIndex]) {")
        code.append("        float len_inv = 1.0f / lengths[rangeIndex];")
    code.append("        __m256 vlen_inv = _mm256_set1_ps(len_inv);")
    code.append("        j = 0;")
    code.append("        for (; j + 8 <= block_size; j += 8) {")
    code.append(
        "          _mm256_storeu_ps(\n"
        "              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));"
    )
    code.append("        }")
    code.append("        for (; j < block_size; j++) {")
    code.append("          op[j] = len_inv * op[j];")
    code.append("        }")

    code.append("      }")

    code.append("    }")
    return code
425
426
# start main code
# Parse generator flags: --fused selects the fused-8bit-rowwise kernels,
# --use-offsets selects the offsets-based ("Idx") API variant.
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--filename", help="file name")
parser.add_argument("--fused", action="store_true")
parser.add_argument("--use-offsets", action="store_true")
opts = parser.parse_args()
# Derive the output .cc filename from the flags unless one was given explicitly.
if opts.filename:
    filename = opts.filename
elif opts.fused:
    if opts.use_offsets:
        filename = "embedding_lookup_fused_8bit_rowwise_idx_avx2.cc"
    else:
        filename = "embedding_lookup_fused_8bit_rowwise_avx2.cc"
else:
    if opts.use_offsets:
        filename = "embedding_lookup_idx_avx2.cc"
    else:
        filename = "embedding_lookup_avx2.cc"
445
# One kernel is generated per row. Columns:
#   [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType]
# where the *Name columns appear in the generated function name and the
# others are the actual C++ types used in the signature/body. Note the
# "int32_t" rows deliberately use C++ "int" as the index type.
options = [
    ["int32_t", "int", "float", "float", "float", "float"],
    ["int64_t", "int64_t", "float", "float", "float", "float"],
    ["int32_t", "int", "half", "at::Half", "float", "float"],
    ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
    ["int32_t", "int", "bfloat16", "at::BFloat16", "float", "float"],
    ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"],
    ["int32_t", "int", "uint8_t", "uint8_t", "float", "float"],
    ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
]
456
# Accumulate all generated C++ source lines here; written out at the end.
code = []
# File banner + includes for the generated translation unit.
code.append("//// --------------------------")
code.append("//// ATTENTION:")
code.append("//// THIS CODE IS AUTOGENERATED")
code.append("//// BY {}".format(sys.argv[0]))
code.append("//// DO NOT MODIFY!!!")
code.append("//// --------------------------\n")

code.append("#include <c10/util/Half.h>")
code.append("#include <c10/util/BFloat16.h>")
code.append("#include <immintrin.h>")

code.append("namespace caffe2 {\n")
# Emit one templated static kernel per (index type, input type) combination,
# plus two thin non-template wrappers (positional / non-positional weights).
for o in options:
    [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o

    prefix = "Fused8BitRowwise" if opts.fused else ""
    code.append("template <bool IS_WEIGHT_POSITIONAL>")
    if opts.use_offsets:
        fn_base = "{}EmbeddingLookupIdx_{}_{}_{}".format(
            prefix, IndexTypeName, InTypeName, OutTypeName
        )
    else:
        fn_base = "{}EmbeddingLookup_{}_{}_{}".format(
            prefix, IndexTypeName, InTypeName, OutTypeName
        )
    suffix = "__avx2_fma"
    fn = "static bool " + fn_base + suffix
    code.append(fn + "(")

    # Parameter list; reused verbatim for the wrapper declarations below.
    args = []
    args.append("    const int64_t block_size,")
    args.append("    const int64_t output_size,")
    args.append("    const int64_t index_size,")
    args.append("    const int64_t data_size,")
    args.append("    const " + InType + "* input,")
    args.append("    const " + IndexType + "* indices,")
    if opts.use_offsets:
        args.append("    const " + IndexType + "* offsets,")
    else:
        args.append("    const int* lengths,")
    args.append("    const float* weights,")
    if not opts.fused:
        args.append("    const float* scale_bias,")
    args.append("    bool normalize_by_lengths,")
    args.append("    " + OutType + "* out) {")
    code += args

    code.append("  const " + IndexType + " prefdist_T0 = 16;")
    code.append("  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)")
    # block_size is the number of elements and fused_block_size is the size of
    # an entire row, including scale and bias.
    offset = (8 // sizeof[InType]) if opts.fused else 0
    code.append(
        "  const {} fused_block_size = block_size + {};".format(IndexType, offset)
    )
    if opts.use_offsets:
        code.append("  int64_t dataInd = 0;")
    else:
        code.append("  " + IndexType + " dataInd = 0;")

    # code.append("printf(\"calling " + fn + "\\n\");");

    # Dispatch on block_size: fully-unrolled fast paths for the common
    # sizes (128/64/32/16 floats), generic fallback otherwise.
    code.append("  if (block_size == 128) {")
    code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 64) {")
    code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 32) {")
    code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 16) {")
    code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else {")
    code.append("    // generic code")
    code.append("    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)")
    code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  }")
    code.append("  return dataInd == index_size;")

    code.append("}")

    # Non-template wrappers instantiating IS_WEIGHT_POSITIONAL = false/true.
    for is_weight_positional in ["false", "true"]:
        code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(")
        code += args
        # Resolve the Lint warnings: Limit of 80 characters in one line.
        extra_space = "\n      "
        ret_string = "  return " + fn_base + suffix + "<" + is_weight_positional + ">("
        if len(ret_string) <= 80:
            code.append(ret_string)
        else:
            code.append("  return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(")
        code.append("      block_size,")
        code.append("      output_size,")
        code.append("      index_size,")
        code.append("      data_size,")
        code.append("      input,")
        code.append("      indices,")
        if opts.use_offsets:
            code.append("      offsets,")
        else:
            code.append("      lengths,")
        code.append("      weights,")
        if not opts.fused:
            code.append("      scale_bias,")
        code.append("      normalize_by_lengths,")
        code.append("      out);")
        code.append("}")

    code.append("")
566
code.append("} // namespace caffe2")

# Write the generated C++ source, one line per list entry.
with open(filename, "w") as fout:
    for c in code:
        # print(c, file = fout)
        fout.write(c + "\n")


print("Created " + filename)
576