xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/EmbeddingBag.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/TensorOperators.h>
5 #include <ATen/TensorSubclassLikeUtils.h>
6 #include <ATen/TensorUtils.h>
7 #include <ATen/cpu/vec/vec.h>
8 #include <ATen/native/EmbeddingBag.h>
9 
10 #include <ATen/native/CPUBlas.h>
11 #include <ATen/native/NonSymbolicBC.h>
12 
13 #include <c10/util/irange.h>
14 #include <c10/util/Half.h>
15 
16 #ifdef USE_FBGEMM
17 #include <fbgemm/Fbgemm.h>
18 #include <fbgemm/FbgemmConvert.h>
19 #else
20 #include <caffe2/perfkernels/embedding_lookup_idx.h>
21 #endif
22 
23 #include <cstring>
24 #include <tuple>
25 #include <utility>
26 #include <vector>
27 
28 #ifndef AT_PER_OPERATOR_HEADERS
29 #include <ATen/Functions.h>
30 #include <ATen/NativeFunctions.h>
31 #else
32 #include <ATen/ops/_embedding_bag.h>
33 #include <ATen/ops/_embedding_bag_backward_native.h>
34 #include <ATen/ops/_embedding_bag_dense_backward.h>
35 #include <ATen/ops/_embedding_bag_dense_backward_native.h>
36 #include <ATen/ops/_embedding_bag_forward_only.h>
37 #include <ATen/ops/_embedding_bag_forward_only_native.h>
38 #include <ATen/ops/_embedding_bag_native.h>
39 #include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
40 #include <ATen/ops/_embedding_bag_sparse_backward.h>
41 #include <ATen/ops/_embedding_bag_sparse_backward_native.h>
42 #include <ATen/ops/embedding_backward_native.h>
43 #include <ATen/ops/embedding_bag_native.h>
44 #include <ATen/ops/empty.h>
45 #include <ATen/ops/max.h>
46 #include <ATen/ops/ones_like.h>
47 #include <ATen/ops/resize_native.h>
48 #include <ATen/ops/zero_native.h>
49 #include <ATen/ops/zeros.h>
50 #endif
51 
52 namespace at::native {
53 
54 template<typename scalar_t>
55 scalar_t dot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy);
56 
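// Maps each position in `indices` to the bag it belongs to. A minimal worked
// example (assuming offsets = [0, 2, 4] and four indices, matching the inline
// comments below): offset2bag starts as zeros of length indices.size(0) + 1;
// index_add_ marks the bag starts, the first entry is reset, and the
// cumulative sum yields [0, 0, 1, 1, 2], which the caller later truncates to
// [0, 0, 1, 1].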
57 static void make_offset2bag(const Tensor &offsets, Tensor& offset2bag) {
58   offset2bag.index_add_(
59       0, offsets, at::ones_like(offsets, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); // offset2bag = [1 0 1 0 1]
60   offset2bag[0] -= 1;                     // offset2bag = [0 0 1 0 1]
61   offset2bag = offset2bag.cumsum(0, offset2bag.scalar_type());     // offset2bag = [0 0 1 1 2]
62 }
63 
64 namespace {
65 
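// indices and offsets may arrive with different integer dtypes (e.g. int32
// offsets with int64 indices). Promote both to their common type, borrowing
// the original tensor when no conversion is needed to avoid a copy.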
66 std::pair<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>> promoteIndicesAndOffsets(
67     const Tensor& indices,
68     const Tensor& offsets) {
69   const auto commonType =
70       promoteTypes(offsets.scalar_type(), indices.scalar_type());
71   return {
72       indices.scalar_type() == commonType ? c10::MaybeOwned<Tensor>::borrowed(indices)
73                                           : c10::MaybeOwned<Tensor>::owned(indices.toType(commonType)),
74       offsets.scalar_type() == commonType ? c10::MaybeOwned<Tensor>::borrowed(offsets)
75                                           : c10::MaybeOwned<Tensor>::owned(offsets.toType(commonType))};
76 }
77 
78 // Determines if we can use a fast implementation for index_select_add, which
79 // requires float/half/bfloat16 weights, unit inner strides, and padding_idx < 0
80 template<typename index_t>
81 bool is_fast_path_index_select(const Tensor& src, Tensor& output, index_t padding_idx) {
82   return (src.scalar_type() == kFloat || src.scalar_type() == kHalf ||
83           src.scalar_type() == kBFloat16) &&
84       src.strides()[1] == 1 && output.strides()[1] == 1 &&
85       padding_idx < static_cast<index_t>(0);
86 }
87 
88 // Determines if we can use a fast implementation for index_select_scale_add:
89 // same conditions as above, plus the per-sample weights (scale) must have stride 1
90 template<typename index_t>
91 bool is_fast_path_index_select_scale(const Tensor& src, const Tensor& scale, Tensor& output, index_t padding_idx) {
92   return (src.scalar_type() == kFloat || src.scalar_type() == kHalf ||
93           src.scalar_type() == kBFloat16) &&
94       src.strides()[1] == 1 && output.strides()[1] == 1 &&
95       scale.strides()[0] == 1 && padding_idx < static_cast<index_t>(0);
96 }
97 
98 template<typename index_t>
99 bool is_fast_path(const Tensor& src, const std::optional<Tensor>& scale, Tensor& output, index_t padding_idx) {
100   return (scale.has_value() && scale.value().defined()) ?
101          is_fast_path_index_select_scale(src, scale.value(), output, padding_idx) :
102          is_fast_path_index_select(src, output, padding_idx);
103 }
104 
105 // This function combines index_select (using select_indices as the index) and
106 // index_add (using add_indices as the index), without creating an intermediary
107 // tensor to hold the selected embeddings
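// In the embedding_bag kernels below, select_indices is the flattened indices
// tensor and add_indices is the offset2bag mapping, so each selected row is
// accumulated directly into its bag's output row.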
108 template <typename data_t, typename index_t>
109 static typename std::enable_if<std::is_same<data_t, double>::value, void>::type
110 index_select_add(
111     const Tensor& select_indices,
112     const Tensor& add_indices,
113     const Tensor& src,
114     Tensor& output,
115     const Tensor& /*offsets*/,
116     bool /*include_last_offset*/,
117     Tensor& bag_size,
118     index_t padding_idx,
119     _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
120   TORCH_CHECK(select_indices.numel() == add_indices.numel());
121   auto* add_indices_data = add_indices.const_data_ptr<index_t>();
122   auto* select_indices_data = select_indices.const_data_ptr<index_t>();
123   auto* src_data = src.const_data_ptr<data_t>();
124   auto* output_data = output.data_ptr<data_t>();
125   index_t* bag_size_data = nullptr;
126   if (bag_size.defined()) {
127     bag_size_data = bag_size.data_ptr<index_t>();
128   }
129   auto numel = add_indices.numel();
130   int64_t ddim = src.size(1);
131   auto vocab_size = src.size(0);
132   auto src_stride0 = src.strides()[0];
133   auto src_stride1 = src.strides()[1];
134   auto output_stride0 = output.strides()[0];
135   auto output_stride1 = output.strides()[1];
136 
137   for (const auto i : c10::irange(numel)) {
138     // We can skip indices equal to padding_idx so they are not included in
139     // the reduction
140     auto idx = select_indices_data[i];
141     TORCH_CHECK(
142         idx >= 0 && idx < vocab_size,
143         "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
144         idx);
145     if (idx != padding_idx) {
146       at::native::cpublas::axpy<data_t>(ddim, 1,
147               src_data + src_stride0 * idx, src_stride1,
148               output_data + output_stride0 * add_indices_data[i], output_stride1);
149     } else if (bag_size_data) {
150       // Decrement bag_size to reflect that the index is padded
151       bag_size_data[add_indices_data[i]]--;
152     }
153   }
154 }
155 
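// Helper used when an FBGEMM EmbeddingSpMDM kernel reports failure: re-scan
// the offsets/indices on the failing range to produce a readable error about
// the out-of-range index (or an inconsistent last offset).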
156 namespace {
157 template <typename index_t>
158 void fbgemm_spmdm_report_error_(
159     int64_t output_size,
160     int index_size,
161     int64_t N,
162     const index_t* offsets,
163     const index_t* indices) {
164   for (const auto m : c10::irange(output_size)) {
165     for (index_t i = offsets[m]; i < offsets[m + 1]; ++i) {
166       TORCH_CHECK(i < index_size);
167       index_t idx = indices[i];
168       TORCH_CHECK(
169           0 <= idx && idx < N,
170           "Index ",
171           i,
172           " of input takes value ",
173           idx,
174           " which is not in the valid range [0, ",
175           N,
176           ")");
177     }
178   }
179   TORCH_CHECK(
180       offsets[output_size] == index_size,
181       "Your input appears to be incorrect: the last offset value should be "
182        "the size of the indices tensor, but it seems not to be the case.");
183 }
184 } // namespace
185 
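// Half/BFloat16 variant. On the fast path this calls a 16-bit FBGEMM
// EmbeddingSpMDM kernel (or the caffe2 EmbeddingLookupIdx fallback, which
// accumulates in fp32 and converts back to 16 bit). On the slow path each
// selected row is upconverted to fp32, accumulated into an fp32 buffer, and
// converted back to the 16-bit output dtype at the end.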
186 template <typename data_t, typename index_t>
187 typename std::enable_if<
188     std::is_same<data_t, at::Half>::value ||
189         std::is_same<data_t, at::BFloat16>::value,
190     void>::type
191 index_select_add(
192     const Tensor& select_indices,
193     const Tensor& add_indices,
194     const Tensor& src,
195     Tensor& output,
196     const Tensor& offsets,
197     bool include_last_offset,
198     Tensor& bag_size,
199     index_t padding_idx,
200     _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
201   int64_t ddim = src.size(1);
202   auto* select_indices_data = select_indices.const_data_ptr<index_t>();
203   auto* output_data = output.data_ptr<data_t>();
204 
205   if (is_fast_path_index_select(src, output, padding_idx)) {
206     auto src_contig = src.contiguous();
207     auto* src_data = src_contig.const_data_ptr<data_t>();
208     int64_t output_size = offsets.numel() - 1;
209     auto* offsets_data = offsets.const_data_ptr<index_t>();
210     std::vector<index_t> offsets_include_last;
211 
212     if (include_last_offset) {
213       output_size = offsets.numel() - 1;
214     } else {
215       output_size = offsets.numel();
216       offsets_include_last.resize(offsets.numel() + 1);
217       if (offsets.numel() > 0) {
218         std::memcpy(
219             offsets_include_last.data(),
220             offsets.const_data_ptr<index_t>(),
221             sizeof(index_t) * offsets.numel());
222       }
223       offsets_include_last[offsets.numel()] = select_indices.numel();
224       offsets_data = offsets_include_last.data();
225     }
226 #if defined(USE_FBGEMM)
227     constexpr bool isbf16 = std::is_same_v<data_t, at::Half> ? false : true;
228     auto kernel_16bit_index_t = fbgemm_kernel_cache
229         ? fbgemm_kernel_cache
230               ->getCallback</* has_weight */ false, index_t, uint16_t>(ddim)
231         : fbgemm::GenerateEmbeddingSpMDM<uint16_t, index_t, index_t, uint16_t>(
232               /* block_size */ ddim,
233               /* has_weight */ false,
234               /* normalize_by_lengths */ false,
235               /* prefetch */ 16,
236               /* is_weight_positional */ false,
237               /* use_offsets */ true,
238               /* is_bf16_out */ isbf16,
239               /* is_bf16_in */ isbf16);
240     at::parallel_for(
241         0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
242           bool success = kernel_16bit_index_t(
243               /* output_size */ end_idx - start_idx,
244               /* index_size */ offsets_data[end_idx] - offsets_data[start_idx],
245               /* data_size */ src.size(0),
246               /* input */ reinterpret_cast<const uint16_t*>(src_data),
247               /* indices */ select_indices_data + offsets_data[start_idx],
248               /* offsets_or_lengths */ offsets_data + start_idx,
249               /* weights */ nullptr,
250               /* output */
251               reinterpret_cast<uint16_t*>(output_data + start_idx * ddim));
252           if (!success) {
253             fbgemm_spmdm_report_error_(
254                 end_idx - start_idx,
255                 offsets_data[end_idx] - offsets_data[start_idx],
256                 src.size(0),
257                 offsets_data + start_idx,
258                 select_indices_data + offsets_data[start_idx]);
259           }
260         });
261 #else
262     // Initialize the intermediate output buffer to be 0.
263     Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
264     auto* output_data_fp32 = output_fp32.data_ptr<float>();
265     using bVec = vec::Vectorized<BFloat16>;
266     using fVec = vec::Vectorized<float>;
267     at::parallel_for(
268         0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
269           caffe2::EmbeddingLookupIdx(
270               /*block_size=*/ddim,
271               /*output_size=*/end_idx - start_idx,
272               /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
273               /*data_size=*/src.size(0),
274               /*input=*/src_data,
275               /*indices=*/select_indices_data + offsets_data[start_idx],
276               /*offsets=*/offsets_data + start_idx,
277               /*weights=*/nullptr,
278               /*scale_bias=*/nullptr,
279               /*normalize_by_lengths=*/false,
280               /*out=*/output_data_fp32 + start_idx * ddim);
281           for (int64_t i = start_idx; i < end_idx; i++) {
282             // Convert FP32 intermediate buffer result back to 16 bit for
283             // output dtype
284             if constexpr (std::is_same<data_t, at::Half>::value) {
285               // FP16
286               for (const auto d : c10::irange(ddim)) {
287                 (output_data + i * ddim)[d] =
288                     static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
289               }
290             } else {
291               // BF16
292               int64_t d = 0;
293               for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) {
294                 fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d);
295                 fVec temp_fp32_1 =
296                     fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size());
297                 convert_float_bfloat16(temp_fp32_0, temp_fp32_1)
298                     .store(output_data + i * ddim + d);
299               }
300               for (; d < ddim; d++) {
301                 (output_data + i * ddim)[d] =
302                     static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
303               }
304             }
305           }
306         });
307 #endif
308   } else {
309     TORCH_CHECK(select_indices.numel() == add_indices.numel());
310     auto* src_data = src.const_data_ptr<data_t>();
311     auto* add_indices_data = add_indices.const_data_ptr<index_t>();
312     index_t* bag_size_data = nullptr;
313     if (bag_size.defined()) {
314       bag_size_data = bag_size.data_ptr<index_t>();
315     }
316     auto vocab_size = src.size(0);
317     auto src_stride0 = src.strides()[0];
318     auto src_stride1 = src.strides()[1];
319     auto output_stride0 = output.strides()[0];
320     auto output_stride1 = output.strides()[1];
321     auto numel = add_indices.numel();
322 
323     Tensor src_fp32 = at::empty({ddim}, src.options().dtype(at::kFloat));
324     auto* src_data_fp32 = src_fp32.mutable_data_ptr<float>();
325 
326     // Initialize the intermediate output buffer to be 0.
327     Tensor output_fp32 =
328         at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
329     auto* output_data_fp32 = output_fp32.data_ptr<float>();
330 
331     for (const auto i : c10::irange(numel)) {
332       // We can skip indices equal to padding_idx so they are not included in
333       // the reduction
334       auto idx = select_indices_data[i];
335       TORCH_CHECK(
336           idx >= 0 && idx < vocab_size,
337           "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
338           idx);
339       if (idx != padding_idx) {
340         // Copy src_data + src_stride0 * idx to src_data_fp32
341         for (const auto d : c10::irange(ddim)) {
342           src_data_fp32[d] = static_cast<float>(
343               (src_data + src_stride0 * idx)[d * src_stride1]);
344         }
345         at::native::cpublas::axpy<float>(
346             ddim,
347             1,
348             src_data_fp32,
349             1,
350             output_data_fp32 + ddim * add_indices_data[i],
351             1);
352 
353       } else if (bag_size_data) {
354         // Decrement bag_size to reflect that the index is padded
355         bag_size_data[add_indices_data[i]]--;
356       }
357     }
358     for (const auto i : c10::irange(output.size(0))) {
359       // Convert FP32 intermediate buffer result back to 16 bit for output
360       // dtype
361       for (const auto d : c10::irange(ddim)) {
362         (output_data + output_stride0 * i)[d * output_stride1] =
363             static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
364       }
365     }
366   }
367 }
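// Float variant: the fast path dispatches to an FBGEMM EmbeddingSpMDM kernel
// (or the caffe2 EmbeddingLookupIdx fallback); the slow path accumulates rows
// with cpublas::axpy using the tensors' strides.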
368 template<typename data_t, typename index_t>
369 typename std::enable_if<std::is_same<data_t, float>::value, void>::type
370 index_select_add(const Tensor &select_indices,
371                              const Tensor &add_indices,
372                              const Tensor &src,
373                              Tensor &output,
374                              const Tensor& offsets,
375                              bool include_last_offset,
376                              Tensor &bag_size,
377                              index_t padding_idx,
378                              _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
379   int64_t ddim = src.size(1);
380   auto* select_indices_data = select_indices.const_data_ptr<index_t>();
381   auto* output_data = output.data_ptr<float>();
382 
383   if (is_fast_path_index_select(src, output, padding_idx)) {
384     auto src_contig = src.contiguous();
385     auto* src_data = src_contig.const_data_ptr<float>();
386     int64_t output_size = offsets.numel() - 1;
387     auto* offsets_data = offsets.const_data_ptr<index_t>();
388     std::vector<index_t> offsets_include_last;
389 
390     if (include_last_offset) {
391       output_size = offsets.numel() - 1;
392     } else {
393       output_size = offsets.numel();
394       offsets_include_last.resize(offsets.numel() + 1);
395       if (offsets.numel() > 0) {
396         std::memcpy(
397             offsets_include_last.data(),
398             offsets.const_data_ptr<index_t>(),
399             sizeof(index_t) * offsets.numel());
400       }
401       offsets_include_last[offsets.numel()] = select_indices.numel();
402       offsets_data = offsets_include_last.data();
403     }
404 
405 #ifdef USE_FBGEMM
406     auto kernel_fp32_index_t =
407       fbgemm_kernel_cache ?
408       fbgemm_kernel_cache->getCallback</* has_weight */ false, index_t, float>(ddim) :
409       fbgemm::GenerateEmbeddingSpMDM<float, index_t, index_t>(
410         /* block_size */ddim,
411         /* has_weight */false,
412         /* normalize_by_lengths */false,
413         /* prefetch */16,
414         /* is_weight_positional */false,
415         /* use_offsets */true
416       );
417 #endif
418     at::parallel_for(
419         0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
420 #ifdef USE_FBGEMM
421           bool success = kernel_fp32_index_t(
422             /* output_size */end_idx - start_idx,
423             /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
424             /* data_size */src.size(0),
425             /* input */src_data,
426             /* indices */select_indices_data + offsets_data[start_idx],
427             /* offsets_or_lengths */offsets_data + start_idx,
428             /* weights */nullptr,
429             /* output */output_data + start_idx * ddim);
430           if (!success) {
431             fbgemm_spmdm_report_error_(
432                 end_idx - start_idx,
433                 offsets_data[end_idx] - offsets_data[start_idx],
434                 src.size(0),
435                 offsets_data + start_idx,
436                 select_indices_data + offsets_data[start_idx]);
437           }
438 #else
439           caffe2::EmbeddingLookupIdx(
440               /*block_size=*/ddim,
441               /*output_size=*/end_idx - start_idx,
442               /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
443               /*data_size=*/src.size(0),
444               /*input=*/src_data,
445               /*indices=*/select_indices_data + offsets_data[start_idx],
446               /*offsets=*/offsets_data + start_idx,
447               /*weights=*/nullptr,
448               /*scale_bias=*/nullptr,
449               /*normalize_by_lengths=*/false,
450               /*out=*/output_data + start_idx * ddim);
451 #endif
452         });
453   } else {
454     AT_ASSERT(select_indices.numel() == add_indices.numel());
455     auto* src_data = src.const_data_ptr<float>();
456     auto* add_indices_data = add_indices.const_data_ptr<index_t>();
457     index_t* bag_size_data = nullptr;
458     if (bag_size.defined()) {
459       bag_size_data = bag_size.data_ptr<index_t>();
460     }
461     auto vocab_size = src.size(0);
462     auto src_stride0 = src.strides()[0];
463     auto src_stride1 = src.strides()[1];
464     auto output_stride0 = output.strides()[0];
465     auto output_stride1 = output.strides()[1];
466     auto numel = add_indices.numel();
467     for (const auto i : c10::irange(numel)) {
468       // We can skip indices equal to padding_idx so they are not included in
469       // the reduction
470       auto idx = select_indices_data[i];
471       TORCH_CHECK(
472           idx >= 0 && idx < vocab_size,
473           "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
474           idx);
475       if (idx != padding_idx) {
476         at::native::cpublas::axpy<float>(
477             ddim,
478             1,
479             src_data + src_stride0 * idx,
480             src_stride1,
481             output_data + output_stride0 * add_indices_data[i],
482             output_stride1);
483       } else if (bag_size_data) {
484         // Decrement bag_size to reflect that the index is padded
485         bag_size_data[add_indices_data[i]]--;
486       }
487     }
488   }
489 }
490 
491 // This function fuses the following three fns:
492 // index_select (using select_indices as the index)
493 // mul (scaling by per_sample_weights)
494 // index_add (using add_indices as the index)
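// Note that `scale` (per_sample_weights) is indexed by position i in
// select_indices, i.e. one weight per looked-up index, not per embedding row.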
495 template <typename data_t, typename index_t>
496 static typename std::enable_if<std::is_same<data_t, double>::value, void>::type
497 index_select_scale_add(
498     const Tensor& select_indices,
499     const Tensor& add_indices,
500     const Tensor& scale,
501     const Tensor& src,
502     Tensor& output,
503     const Tensor& /*offsets*/,
504     bool /*include_last_offset*/,
505     Tensor& bag_size,
506     index_t padding_idx,
507     _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
508   AT_ASSERT(select_indices.numel() == add_indices.numel());
509   auto* add_indices_data = add_indices.const_data_ptr<index_t>();
510   auto* select_indices_data = select_indices.const_data_ptr<index_t>();
511   auto* src_data = src.const_data_ptr<data_t>();
512   auto* output_data = output.data_ptr<data_t>();
513   index_t* bag_size_data = nullptr;
514   if (bag_size.defined()) {
515     bag_size_data = bag_size.data_ptr<index_t>();
516   }
517   auto numel = add_indices.numel();
518   int64_t ddim = src.size(1);
519   auto vocab_size = src.size(0);
520   auto src_stride0 = src.strides()[0];
521   auto src_stride1 = src.strides()[1];
522   auto output_stride0 = output.strides()[0];
523   auto output_stride1 = output.strides()[1];
524 
525   auto* scale_data = scale.const_data_ptr<data_t>();
526   auto scale_stride = scale.strides()[0];
527 
528   for (const auto i : c10::irange(numel)) {
529     // We can skip indices equal to padding_idx so they are not included in
530     // the reduction
531     auto idx = select_indices_data[i];
532     TORCH_CHECK(
533         idx >= 0 && idx < vocab_size,
534         "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
535         idx);
536     if (idx != padding_idx) {
537       auto* src_base = src_data + src_stride0 * idx;
538       auto* output_base = output_data + output_stride0 * add_indices_data[i];
539       auto scale = scale_data[i * scale_stride];
540       for (const auto j : c10::irange(ddim)) {
541         output_base[j * output_stride1] += src_base[j * src_stride1] * scale;
542       }
543     } else if (bag_size_data) {
544       // Decrement bag_size to reflect that the index is padded
545       bag_size_data[add_indices_data[i]]--;
546     }
547   }
548 }
549 
550 template <typename data_t, typename index_t>
551 typename std::enable_if<
552     std::is_same<data_t, at::Half>::value ||
553         std::is_same<data_t, at::BFloat16>::value,
554     void>::type
555 index_select_scale_add(
556     const Tensor& select_indices,
557     const Tensor& add_indices,
558     const Tensor& scale,
559     const Tensor& src,
560     Tensor& output,
561     const Tensor& offsets,
562     bool include_last_offset,
563     Tensor& bag_size,
564     index_t padding_idx,
565     _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
566   int64_t ddim = src.size(1);
567   auto* scale_data = scale.const_data_ptr<data_t>();
568   auto* select_indices_data = select_indices.const_data_ptr<index_t>();
569   auto* output_data = output.data_ptr<data_t>();
570 
571   if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) {
572     auto src_contig = src.contiguous();
573     auto* src_data = src_contig.const_data_ptr<data_t>();
574     int64_t output_size = offsets.numel() - 1;
575     auto* offsets_data = offsets.const_data_ptr<index_t>();
576     std::vector<index_t> offsets_include_last;
577 
578     if (include_last_offset) {
579       output_size = offsets.numel() - 1;
580     } else {
581       output_size = offsets.numel();
582       offsets_include_last.resize(offsets.numel() + 1);
583       std::memcpy(
584           offsets_include_last.data(),
585           offsets.const_data_ptr<index_t>(),
586           sizeof(index_t) * offsets.numel());
587       offsets_include_last[offsets.numel()] = select_indices.numel();
588       offsets_data = offsets_include_last.data();
589     }
590 
591     Tensor scale_fp32 = at::empty(scale.sizes(), scale.options().dtype(at::kFloat));
592     auto* scale_data_fp32 = scale_fp32.mutable_data_ptr<float>();
593 
594 #if defined(USE_FBGEMM)
595     constexpr bool isbf16 = std::is_same_v<data_t, at::Half> ? false : true;
596     if constexpr (isbf16) {
597       fbgemm::Bfloat16ToFloat_simd(
598           reinterpret_cast<const fbgemm::bfloat16*>(scale_data),
599           scale_data_fp32,
600           scale_fp32.numel());
601     } else {
602       fbgemm::Float16ToFloat_simd(
603           reinterpret_cast<const fbgemm::float16*>(scale_data),
604           scale_data_fp32,
605           scale_fp32.numel());
606     }
607     auto kernel_16bit_index_t = fbgemm_kernel_cache
608         ? fbgemm_kernel_cache
609               ->getCallback</* has_weight */ true, index_t, uint16_t>(ddim)
610         : fbgemm::GenerateEmbeddingSpMDM<uint16_t, index_t, index_t, uint16_t>(
611               /* block_size */ ddim,
612               /* has_weight */ true,
613               /* normalize_by_lengths */ false,
614               /* prefetch */ 16,
615               /* is_weight_positional */ false,
616               /* use_offsets */ true,
617               /* is_bf16_out */ isbf16,
618               /* is_bf16_in */ isbf16);
619     at::parallel_for(
620         0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
621           bool success = kernel_16bit_index_t(
622               /* output_size */ end_idx - start_idx,
623               /* index_size */ offsets_data[end_idx] - offsets_data[start_idx],
624               /* data_size */ src.size(0),
625               /* input */ reinterpret_cast<const uint16_t*>(src_data),
626               /* indices */ select_indices_data + offsets_data[start_idx],
627               /* offsets_or_lengths */ offsets_data + start_idx,
628               /* weights */ scale_data_fp32 + offsets_data[start_idx],
629               /* output */
630               reinterpret_cast<uint16_t*>(output_data + start_idx * ddim));
631           if (!success) {
632             fbgemm_spmdm_report_error_(
633                 end_idx - start_idx,
634                 offsets_data[end_idx] - offsets_data[start_idx],
635                 src.size(0),
636                 offsets_data + start_idx,
637                 select_indices_data + offsets_data[start_idx]);
638           }
639         });
640 #else
641     // Initialize the intermediate output buffer to be 0.
642     Tensor output_fp32 =
643         at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
644     auto* output_data_fp32 = output_fp32.data_ptr<float>();
645     for (const auto i : c10::irange(scale.numel())) {
646       scale_data_fp32[i] = static_cast<float>(scale_data[i]);
647     }
648     using bVec = vec::Vectorized<BFloat16>;
649     using fVec = vec::Vectorized<float>;
650     at::parallel_for(
651         0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
652           caffe2::EmbeddingLookupIdx(
653               /*block_size=*/ddim,
654               /*output_size=*/end_idx - start_idx,
655               /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
656               /*data_size=*/src.size(0),
657               /*input=*/src_data,
658               /*indices=*/select_indices_data + offsets_data[start_idx],
659               /*offsets=*/offsets_data + start_idx,
660               /*weights=*/scale_data_fp32 + offsets_data[start_idx],
661               /*scale_bias=*/nullptr,
662               /*normalize_by_lengths=*/false,
663               /*out=*/output_data_fp32 + start_idx * ddim);
664           for (int64_t i = start_idx; i < end_idx; i++) {
665             // Convert FP32 intermediate buffer result back to 16 bit for
666             // output dtype
667             if constexpr (std::is_same<data_t, at::Half>::value) {
668               // FP16
669               for (const auto d : c10::irange(ddim)) {
670                 (output_data + i * ddim)[d] =
671                     static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
672               }
673             } else {
674               // BF16
675               int64_t d = 0;
676               for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) {
677                 fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d);
678                 fVec temp_fp32_1 =
679                     fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size());
680                 convert_float_bfloat16(temp_fp32_0, temp_fp32_1)
681                     .store(output_data + i * ddim + d);
682               }
683               for (; d < ddim; d++) {
684                 (output_data + i * ddim)[d] =
685                     static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
686               }
687             }
688           }
689         });
690 #endif
691   } else {
692     AT_ASSERT(select_indices.numel() == add_indices.numel());
693     auto* src_data = src.const_data_ptr<data_t>();
694     auto* add_indices_data = add_indices.const_data_ptr<index_t>();
695     index_t* bag_size_data = nullptr;
696     if (bag_size.defined()) {
697       bag_size_data = bag_size.data_ptr<index_t>();
698     }
699     auto vocab_size = src.size(0);
700     auto src_stride0 = src.strides()[0];
701     auto src_stride1 = src.strides()[1];
702     auto output_stride0 = output.strides()[0];
703     auto output_stride1 = output.strides()[1];
704     auto scale_stride = scale.strides()[0];
705     auto numel = add_indices.numel();
706 
707     // Initialize the intermediate output buffer to be 0.
708     Tensor output_fp32 =
709         at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
710     auto* output_data_fp32 = output_fp32.data_ptr<float>();
711 
712     for (const auto i : c10::irange(numel)) {
713       // We can skip indices equal to padding_idx so they are not included in
714       // the reduction
715       auto idx = select_indices_data[i];
716       TORCH_CHECK(
717           idx >= 0 && idx < vocab_size,
718           "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
719           idx);
720       if (idx != padding_idx) {
721         auto* src_base = src_data + src_stride0 * idx;
722         auto* output_base_fp32 = output_data_fp32 + ddim * add_indices_data[i];
723         auto scale = scale_data[i * scale_stride];
724         for (const auto j : c10::irange(ddim)) {
725           output_base_fp32[j] += static_cast<float>(src_base[j * src_stride1]) *
726               static_cast<float>(scale);
727         }
728       } else if (bag_size_data) {
729         // Decrement bag_size to reflect that the index is padded
730         bag_size_data[add_indices_data[i]]--;
731       }
732     }
733     for (const auto i : c10::irange(output.size(0))) {
734       // Convert FP32 intermediate buffer result back to 16 bit for output
735       // dtype
736       for (const auto d : c10::irange(ddim)) {
737         (output_data + output_stride0 * i)[d * output_stride1] =
738             static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
739       }
740     }
741   }
742 }
743 template<typename data_t, typename index_t>
744 typename std::enable_if<std::is_same<data_t, float>::value, void>::type
745 index_select_scale_add(const Tensor &select_indices,
746                                           const Tensor &add_indices,
747                                           const Tensor &scale,
748                                           const Tensor &src,
749                                           Tensor &output,
750                                           const Tensor& offsets,
751                                           bool include_last_offset,
752                                           Tensor &bag_size,
753                                           index_t padding_idx,
754                                           _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
755   int64_t ddim = src.size(1);
756   auto* scale_data = scale.const_data_ptr<float>();
757   auto* select_indices_data = select_indices.const_data_ptr<index_t>();
758   auto* output_data = output.data_ptr<float>();
759 
760   if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) {
761     auto src_contig = src.contiguous();
762     auto* src_data = src_contig.const_data_ptr<float>();
763     int64_t output_size = offsets.numel() - 1;
764     auto* offsets_data = offsets.const_data_ptr<index_t>();
765     std::vector<index_t> offsets_include_last;
766 
767     if (include_last_offset) {
768       output_size = offsets.numel() - 1;
769     } else {
770       output_size = offsets.numel();
771       offsets_include_last.resize(offsets.numel() + 1);
772       std::memcpy(
773           offsets_include_last.data(),
774           offsets.const_data_ptr<index_t>(),
775           sizeof(index_t) * offsets.numel());
776       offsets_include_last[offsets.numel()] = select_indices.numel();
777       offsets_data = offsets_include_last.data();
778     }
779 
780 #ifdef USE_FBGEMM
781     auto kernel_fp32_index_t =
782       fbgemm_kernel_cache ?
783       fbgemm_kernel_cache->getCallback</* has_weight */ true, index_t, float>(ddim) :
784       fbgemm::GenerateEmbeddingSpMDM<float, index_t, index_t>(
785         /* block_size */ddim,
786         /* has_weight */true,
787         /* normalize_by_lengths */false,
788         /* prefetch */16,
789         /* is_weight_positional */false,
790         /* use_offsets */true
791       );
792 #endif
793     at::parallel_for(
794         0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
795 #ifdef USE_FBGEMM
796           bool success = kernel_fp32_index_t(
797             /* output_size */end_idx - start_idx,
798             /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
799             /* data_size */src.size(0),
800             /* input */src_data,
801             /* indices */select_indices_data + offsets_data[start_idx],
802             /* offsets_or_lengths */offsets_data + start_idx,
803             /* weights */scale_data + offsets_data[start_idx],
804             /* output */output_data + start_idx * ddim);
805           if (!success) {
806             fbgemm_spmdm_report_error_(
807                 end_idx - start_idx,
808                 offsets_data[end_idx] - offsets_data[start_idx],
809                 src.size(0),
810                 offsets_data + start_idx,
811                 select_indices_data + offsets_data[start_idx]);
812           }
813 #else
814           caffe2::EmbeddingLookupIdx(
815               /*block_size=*/ddim,
816               /*output_size=*/end_idx - start_idx,
817               /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
818               /*data_size=*/src.size(0),
819               /*input=*/src_data,
820               /*indices=*/select_indices_data + offsets_data[start_idx],
821               /*offsets=*/offsets_data + start_idx,
822               /*weights=*/scale_data + offsets_data[start_idx],
823               /*scale_bias=*/nullptr,
824               /*normalize_by_lengths=*/false,
825               /*out=*/output_data + start_idx * ddim);
826 #endif
827         });
828   } else {
829     AT_ASSERT(select_indices.numel() == add_indices.numel());
830     auto* src_data = src.const_data_ptr<float>();
831     auto* add_indices_data = add_indices.const_data_ptr<index_t>();
832     index_t* bag_size_data = nullptr;
833     if (bag_size.defined()) {
834       bag_size_data = bag_size.data_ptr<index_t>();
835     }
836     auto vocab_size = src.size(0);
837     auto src_stride0 = src.strides()[0];
838     auto src_stride1 = src.strides()[1];
839     auto output_stride0 = output.strides()[0];
840     auto output_stride1 = output.strides()[1];
841     auto scale_stride = scale.strides()[0];
842     auto numel = add_indices.numel();
843 
844 
845     for (const auto i : c10::irange(numel)) {
846       // We can skip indices equal to padding_idx so they are not included in
847       // the reduction
848       auto idx = select_indices_data[i];
849       TORCH_CHECK(
850           idx >= 0 && idx < vocab_size,
851           "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
852           idx);
853       if (idx != padding_idx) {
854         auto* src_base = src_data + src_stride0 * idx;
855         auto* output_base = output_data + output_stride0 * add_indices_data[i];
856         auto scale = scale_data[i * scale_stride];
857         for (const auto j : c10::irange(ddim)) {
858           output_base[j * output_stride1] += src_base[j * src_stride1] * scale;
859         }
860       } else if (bag_size_data) {
861         // Decrement bag_size to reflect that the index is padded
862         bag_size_data[add_indices_data[i]]--;
863       }
864     }
865   }
866 }
867 
868 }  // namespace
869 
870 void check_arguments(
871     const Tensor& weight,
872     const Tensor& indices,
873     const Tensor& offsets,
874     const int64_t mode,
875     const std::optional<Tensor>& per_sample_weights,
876     bool include_last_offset) {
877   auto indices_arg = TensorArg(indices, "indices", 1);
878   checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
879   auto offsets_arg = TensorArg(offsets, "offsets", 1);
880   checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt});
881   checkSameType("embedding_bag", indices_arg, offsets_arg);
882   auto weight_arg = TensorArg(weight, "weight", 1);
883   checkScalarTypes(
884       "embedding_bag", weight_arg, {kHalf, kBFloat16, kFloat, kDouble});
885 
886   AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_embedding_bag_cpu_impl", [&]() {
887     if (offsets.size(0) > 0) {
888       index_t offset_0 = offsets.const_data_ptr<index_t>()[0];
889       index_t offset_n = offsets.const_data_ptr<index_t>()[offsets.size(0)-1];
890       TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence "
891                                 "in the mini-batch has to start from position 0. "
892                                 "However, got ", offsets[0]);
893       TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not "
894                   "be greater than input's length ", indices.size(0), " but got offsets[-1] of ",
895                   offset_n);
896     }
897   });
898 
899   if (per_sample_weights.has_value() && per_sample_weights.value().defined()) {
900     TORCH_CHECK(
901         mode == EmbeddingBagMode::SUM,
902         "embedding_bag: per_sample_weights only supported with mode='sum'");
903     auto per_input_weights_arg = TensorArg(
904         per_sample_weights.value(),"per_sample_weights", 1);
905     checkSameType("embedding_bag", weight_arg, per_input_weights_arg);
906     TORCH_CHECK(per_sample_weights.value().dim() == 1);
907     TORCH_CHECK(per_sample_weights.value().numel() == indices.numel());
908   }
909 
910   if (include_last_offset) {
911     TORCH_CHECK(
912         offsets.size(0) >= 1,
913         "include_last_offset: number of offset should be at least 1");
914   }
915 }
916 
917 void make_bag_size_out(
918     Tensor& bag_size_out,
919     const Tensor& offsets,
920     const Tensor& indices,
921     const int64_t mode,
922     const bool include_last_offset,
923     const bool requires_grad) {
924   if (requires_grad || mode == EmbeddingBagMode::MEAN ||
925       mode == EmbeddingBagMode::MAX) {
926     auto num_bags = offsets.size(0) - (include_last_offset ? 1 : 0);
927     at::native::resize_(bag_size_out, {num_bags}, std::nullopt);
928 // Compute this for EmbeddingBagMode::MEAN and EmbeddingBagMode::MAX (the
929 // latter is needed for the backward pass)
930     if (num_bags != 1) {
931       bag_size_out.slice(0, 0, bag_size_out.size(0) - 1, 1) =
932           offsets.slice(0, 1, num_bags, 1) -
933           offsets.slice(0, 0, num_bags - 1, 1);
934     }
935     if (num_bags > 0) {
936       bag_size_out[-1] = indices.size(0) - offsets[num_bags - 1];
937     }
938   } else {
939     at::native::resize_(bag_size_out, offsets.sizes(), std::nullopt);
940   }
941 }
942 
943 void make_max_indices_out(
944     Tensor& max_indices_out,
945     const Tensor& weight,
946     const Tensor& indices,
947     const Tensor& offsets,
948     const Tensor& bag_size,
949     const int64_t mode,
950     bool include_last_offset) {
951   int64_t numBags = offsets.size(0);
952   if (mode == EmbeddingBagMode::MAX) {
953     if (include_last_offset) {
954       TORCH_CHECK(
955         numBags >= 1, "include_last_offset: numBags should be at least 1");
956       numBags -= 1;
957     }
958     at::native::resize_(max_indices_out, {numBags, weight.sizes()[1]}, std::nullopt);
959     at::native::zero_(max_indices_out);
960   } else {
961     at::native::resize_(max_indices_out, bag_size.sizes(), std::nullopt);
962   }
963 }
964 
965 void make_offset2bag_out(
966     Tensor& offset2bag,
967     Tensor& output,
968     const Tensor& weight,
969     const Tensor& indices,
970     const Tensor& offsets,
971     const int64_t mode,
972     const std::optional<Tensor>& per_sample_weights,
973     const int64_t padding_idx) {
974 // To save compute, if we are going to take the fast path for 'sum' mode,
975 // we skip calculating offset2bag, since it is not going to be used.
976   bool fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx);
977 
978   if (mode == EmbeddingBagMode::MEAN || mode == EmbeddingBagMode::MAX ||
979       !fast_path_sum) {
980     at::native::resize_(offset2bag, {indices.size(0) + 1}, std::nullopt);
981     at::native::zero_(offset2bag);
982 
983     int64_t offsets_size = offsets.size(0);
984     bool include_last_offset = (output.size(0) == offsets_size - 1);
985 // When include_last_offset is true, ignore the last entry in offsets.
986 // This fixes a segfault when include_last_offset is true and offsets[-1] != indices.size(0);
987 // see https://github.com/pytorch/pytorch/issues/89677 for more details.
988     Tensor _offsets = offsets;
989     if (include_last_offset) {
990       _offsets = offsets.narrow(0, 0, offsets_size - 1);
991     }
992     make_offset2bag(_offsets, offset2bag);
993     at::native::resize_(offset2bag, {indices.size(0)}, std::nullopt);
994     // only initialize output in slow path
995     at::native::zero_(output);
996   }
997 }
998 
999 static Tensor make_bag_size(
1000     const Tensor& offsets,
1001     const Tensor& indices,
1002     const int64_t mode,
1003     const bool include_last_offset,
1004     const bool requires_grad) {
1005   Tensor bag_size = at::empty(offsets.sizes(), offsets.options());
1006   make_bag_size_out(bag_size, offsets, indices, mode, include_last_offset, requires_grad);
1007   return bag_size;
1008 }
1009 
1010 static Tensor make_max_indices(
1011     const Tensor& weight,
1012     const Tensor& indices,
1013     const Tensor& offsets,
1014     const Tensor& bag_size,
1015     const int64_t mode,
1016     bool include_last_offset) {
1017   Tensor max_indices = at::empty(bag_size.sizes(), offsets.options());
1018   make_max_indices_out(max_indices, weight, indices, offsets, bag_size, mode, include_last_offset);
1019   return max_indices;
1020 }
1021 
1022 static Tensor make_offset2bag(
1023     Tensor& output,
1024     const Tensor& weight,
1025     const Tensor& indices,
1026     const Tensor& offsets,
1027     const int64_t mode,
1028     const std::optional<Tensor>& per_sample_weights,
1029     const int64_t padding_idx) {
1030   Tensor offset2bag = at::empty({0}, offsets.options());
1031   make_offset2bag_out(offset2bag, output, weight, indices, offsets, mode, per_sample_weights, padding_idx);
1032   return offset2bag;
1033 }
1034 
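// For 'mean' mode, divide each bag's summed embedding by its bag size,
// clamped to at least 1 so that empty bags do not divide by zero.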
1035 static Tensor apply_bag_size(
1036     const int64_t mode,
1037     Tensor &output,
1038     const Tensor &bag_size) {
1039   if (mode == EmbeddingBagMode::MEAN) {
1040     auto bag_size_ = at::max(bag_size, at::ones_like(bag_size, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
1041                          .to(output.options())
1042                          .unsqueeze(1)
1043                          .expand_as(output);
1044     output /= bag_size_;
1045   }
1046   return output;
1047 }
1048 
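// Backward counterpart for 'mean' mode: scale each per-index gradient row by
// 1 / bag_size of the bag it belongs to, gathered via offset2bag.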
1049 static Tensor apply_bag_size_backward(
1050     const int64_t mode,
1051     Tensor &output,
1052     const Tensor &offset2bag,
1053     const Tensor &bag_size) {
1054   if (mode == EmbeddingBagMode::MEAN) {
1055     auto inv_bag_size_ = (1 / bag_size.to(output.options()))
1056                            .unsqueeze(1)
1057                            .index_select(0, offset2bag);
1058     output *= inv_bag_size_;
1059   }
1060   return output;
1061 }
1062 
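// 'max' mode: for every bag, take the elementwise maximum over the embedding
// rows of its (non-padding) indices, recording the winning row index per
// feature in max_indices for the backward pass.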
1063 template <typename scalar_t>
1064 void embedding_bag_cpu_max_out(
1065     Tensor* max_indices,
1066     const Tensor& weight,
1067     const Tensor& indices,
1068     const Tensor& offset2bag,
1069     const Tensor& output,
1070     bool include_last_offset,
1071     Tensor& bag_size,
1072     int64_t padding_idx) {
1073   int64_t numIndices = indices.numel();
1074   int64_t featureSize = weight.size(1);
1075   int64_t vocab_size = weight.size(0);
1076   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu_max_out", [&] {
1077     auto* indices_data = indices.const_data_ptr<index_t>();
1078     auto* offset2bag_data = offset2bag.data_ptr<index_t>();
1079 
1080     index_t* max_indices_data = nullptr;
1081     int64_t max_indices_stride = 0;
1082     if (max_indices) {
1083       max_indices_data = max_indices->data_ptr<index_t>();
1084       max_indices_stride = max_indices->strides()[0];
1085     }
1086 
1087     auto* weight_data = weight.const_data_ptr<scalar_t>();
1088     auto* output_data = output.data_ptr<scalar_t>();
1089     auto* bag_size_data = bag_size.data_ptr<index_t>();
1090     auto weight_stride0 = weight.strides()[0];
1091     auto weight_stride1 = weight.strides()[1];
1092     auto output_stride = output.strides()[0];
1093     int64_t numBags = bag_size.size(0);
1094     std::vector<bool> bag_empty(numBags, true);
1095 
1096     for (const auto i : c10::irange(numIndices)) {
1097       auto bag = offset2bag_data[i];
1098       auto word_idx = indices_data[i];
1099       TORCH_CHECK(
1100           word_idx >= 0 && word_idx < vocab_size,
1101           "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
1102           word_idx);
1103       if (word_idx != static_cast<index_t>(padding_idx)) {
1104         bool is_first_for_bag = bag_empty[bag];
1105         for (const auto dim : c10::irange(featureSize)) {
1106           auto& current_item = output_data[output_stride * bag + dim];
1107           auto weight_item =
1108               weight_data[weight_stride0 * word_idx + dim * weight_stride1];
1109 
1110           if (is_first_for_bag || (weight_item > current_item)) {
1111             current_item = weight_item;
1112             if (max_indices_data) {
1113               max_indices_data[max_indices_stride * bag + dim] = word_idx;
1114             }
1115           }
1116         }
1117         if (is_first_for_bag) {
1118           bag_empty[bag] = false;
1119         }
1120       } else {
1121         // Decrement bag_size to reflect that the index is padded
1122         bag_size_data[bag]--;
1123       }
1124     }
1125   });
1126 }
1127 
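// Core dispatcher: 'sum' and 'mean' go through index_select_add /
// index_select_scale_add (then 'mean' divides by bag sizes via
// apply_bag_size); 'max' goes through embedding_bag_cpu_max_out.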
1128 void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
1129                             Tensor& bag_size, Tensor* max_indices,
1130                             const Tensor &weight, const Tensor &indices,
1131                             const Tensor &offsets, const int64_t mode,
1132                             const std::optional<Tensor>& per_sample_weights,
1133                             bool include_last_offset, int64_t padding_idx, _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
1134   if (mode == EmbeddingBagMode::MEAN || mode == EmbeddingBagMode::SUM) {
1135     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_no_grad_cpu_out",
1136       [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode, &bag_size, &padding_idx, &fbgemm_kernel_cache]() {
1137       AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_no_grad_cpu_out",
1138         [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode, &bag_size, &padding_idx, &fbgemm_kernel_cache]() {
1139         if (per_sample_weights.has_value() && per_sample_weights.value().defined()) {
1140           TORCH_INTERNAL_ASSERT(mode == EmbeddingBagMode::SUM);
1141           index_select_scale_add<scalar_t, index_t>(
1142             indices, offset2bag, per_sample_weights.value(), weight, output, offsets, include_last_offset, bag_size, padding_idx, fbgemm_kernel_cache);
1143         } else {
1144           index_select_add<scalar_t, index_t>(indices, offset2bag, weight, output, offsets, include_last_offset, bag_size, padding_idx, fbgemm_kernel_cache);
1145         }
1146       });
1147     });
1148     apply_bag_size(mode, output, bag_size);
1149     if (mode == EmbeddingBagMode::SUM) {
1150       // make bag_size output deterministic
1151       at::native::zero_(bag_size);
1152     }
1153     if (max_indices) {
1154       max_indices->copy_(bag_size);
1155     }
1156   } else { // EmbeddingBagMode::MAX
1157     AT_DISPATCH_FLOATING_TYPES_AND2(
1158         at::ScalarType::Half,
1159         at::ScalarType::BFloat16,
1160         weight.scalar_type(),
1161         "embedding_bag_cpu_max_out",
1162         [&]() {
1163           embedding_bag_cpu_max_out<scalar_t>(
1164               max_indices,
1165               weight,
1166               indices,
1167               offset2bag,
1168               output,
1169               include_last_offset,
1170               bag_size,
1171               padding_idx);
1172         });
1173   }
1174 }
1175 
1176 // Assumes all input tensors except for `weight` are contiguous.
1177 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
1178 static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
1179     const Tensor& weight,
1180     const Tensor& indices_,
1181     const Tensor& offsets_,
1182     const int64_t mode,
1183     const Tensor& per_sample_weights,
1184     bool include_last_offset,
1185     int64_t padding_idx,
1186     bool requires_grad) {
1187   TORCH_CHECK(indices_.dim() == 1 || indices_.dim() == 2,
1188       "input has to be a 1D or 2D Tensor, but got Tensor of dimension ",
1189       indices_.dim());
1190   if (indices_.dim() == 1) {
1191     TORCH_CHECK(offsets_.dim() == 1,
1192         "offsets has to be a 1D Tensor, but got Tensor of dimension ",
1193         offsets_.dim());
1194   }
1195   TORCH_CHECK(weight.dim() == 2,
1196       "weight has to be a 2D Tensor, but got Tensor of dimension ",
1197       weight.dim());
1198   auto [indicesMaybeOwned, offsetsMaybeOwned] = promoteIndicesAndOffsets(indices_, offsets_);
1199   const auto& indices = *indicesMaybeOwned;
1200   const auto& offsets = *offsetsMaybeOwned;
1201   check_arguments(weight, indices, offsets, mode, per_sample_weights, include_last_offset);
1202 
1203   Tensor output = at::empty(
1204       {include_last_offset ? offsets.size(0) - 1 : offsets.size(0),
1205        weight.sizes()[1]},
1206       weight.options());
1207 
1208   Tensor offset2bag = make_offset2bag(output, weight, indices, offsets, mode, per_sample_weights, padding_idx);
1209 
1210   Tensor bag_size = make_bag_size(offsets, indices, mode, include_last_offset, requires_grad);
1211 
1212   Tensor max_indices = make_max_indices(weight, indices, offsets, bag_size, mode, include_last_offset);
1213 
1214   _embedding_bag_cpu_impl_out(output, offset2bag,
1215                           bag_size, &max_indices,
1216                           weight, indices, offsets,
1217                           mode, per_sample_weights,
1218                           include_last_offset, padding_idx);
1219 
1220   return std::make_tuple(std::move(output), std::move(offset2bag), std::move(bag_size), std::move(max_indices));
1221 }
1222 
1223 // embedding_bag wrapper to enforce contiguity in tensors other than `weight`.
1224 // This is created to save an extra `.contiguous()` call in the backward pass.
1225 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
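//
// Illustrative example of the (indices, offsets) encoding expected here:
//   indices = [1, 2, 4, 5, 4, 3], offsets = [0, 2, 4]
// describes three bags {1, 2}, {4, 5}, {4, 3}; each bag is reduced to a single
// output row according to `mode` (sum/mean/max).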
1226 std::tuple<Tensor, Tensor, Tensor, Tensor>
1227 embedding_bag(const Tensor &weight, const Tensor &indices,
1228               const Tensor &offsets, const bool scale_grad_by_freq,
1229               const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
1230               bool include_last_offset, std::optional<int64_t> padding_idx_opt) {
1231   // See [Note: hacky wrapper removal for optional tensor]
1232   c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
1233   const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
1234   int64_t padding_idx = -1;
1235 
1236   if (padding_idx_opt.has_value()) {
1237     auto num_embeddings = weight.size(0);
1238     padding_idx = padding_idx_opt.value();
1239     TORCH_CHECK(
1240       (padding_idx >= -num_embeddings) && (padding_idx < num_embeddings),
1241       "padding_idx must be within the number of embeddings, -", num_embeddings,
1242       " through ", num_embeddings - 1, ", but got ", padding_idx);
1243     padding_idx = maybe_wrap_dim(padding_idx, weight.size(0));
1244   }
1245   std::tuple<Tensor, Tensor, Tensor, Tensor> out;
1246   if (!weight.requires_grad() && !weight._fw_grad(/*level=*/0).defined()) {
1247     out = at::_embedding_bag_forward_only(
1248       weight, indices.contiguous(), offsets.contiguous(), scale_grad_by_freq,
1249       mode, sparse, per_sample_weights, include_last_offset, padding_idx);
1250   } else {
1251     out = at::_embedding_bag(
1252       weight, indices.contiguous(), offsets.contiguous(), scale_grad_by_freq,
1253       mode, sparse, per_sample_weights, include_last_offset, padding_idx);
1254   }
1255   return out;
1256 }
1257 
1258 std::tuple<Tensor, Tensor, Tensor, Tensor>
1259 embedding_bag(const Tensor &weight, const Tensor &indices,
1260               const Tensor &offsets, const bool scale_grad_by_freq,
1261               const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
1262               bool include_last_offset) {
1263   return at::native::embedding_bag(weight, indices, offsets, scale_grad_by_freq,
1264       mode, sparse, per_sample_weights_opt, include_last_offset, std::nullopt);
1265 }
1266 
1267 // Assumes all input tensors except for `weight` are contiguous.
1268 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
1269 std::tuple<Tensor, Tensor, Tensor, Tensor>
1270 _embedding_bag_forward_only_cpu(const Tensor &weight, const Tensor &indices,
1271                   const Tensor &offsets, const bool scale_grad_by_freq,
1272                   const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt, bool include_last_offset,
1273                   int64_t padding_idx) {
1274   // See [Note: hacky wrapper removal for optional tensor]
1275   c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
1276   const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
1277   std::ignore = scale_grad_by_freq;
1278   std::ignore = sparse;
1279   return _embedding_bag_cpu_impl(
1280       weight,
1281       indices,
1282       offsets,
1283       mode,
1284       per_sample_weights,
1285       include_last_offset,
1286       padding_idx,
1287       /*requires_grad=*/false);
1288 }
1289 
1290 // Assumes all input tensors except for `weight` are contiguous.
1291 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
1292 std::tuple<Tensor, Tensor, Tensor, Tensor>
1293 _embedding_bag_cpu(const Tensor &weight, const Tensor &indices,
1294                   const Tensor &offsets, const bool scale_grad_by_freq,
1295                   const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt, bool include_last_offset,
1296                   int64_t padding_idx) {
1297   // See [Note: hacky wrapper removal for optional tensor]
1298   c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
1299   const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
1300 
1301   std::ignore = scale_grad_by_freq;
1302   std::ignore = sparse;
1303   return _embedding_bag_cpu_impl(
1304       weight,
1305       indices,
1306       offsets,
1307       mode,
1308       per_sample_weights,
1309       include_last_offset,
1310       padding_idx,
1311       /*requires_grad=*/true);
1312 }
1313 
1314 void _embedding_bag_cpu_out(
1315     at::Tensor& output,
1316     at::Tensor& offset2bag,
1317     at::Tensor& bag_size,
1318     at::Tensor* p_max_indices,
1319     const at::Tensor& weight,
1320     const at::Tensor& indices_,
1321     const at::Tensor& offsets_,
1322     const bool /* scale_grad_by_freq */,
1323     const int64_t mode,
1324     const bool /* sparse */,
1325     const std::optional<at::Tensor>& per_sample_weights,
1326     const bool include_last_offset,
1327     const std::optional<int64_t>& padding_idx,
1328     _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
1329   auto [indicesMaybeOwned, offsetsMaybeOwned] = promoteIndicesAndOffsets(indices_, offsets_);
1330   const auto& indices = *indicesMaybeOwned;
1331   const auto& offsets = *offsetsMaybeOwned;
1332   at::native::check_arguments(
1333       weight, indices, offsets, mode, per_sample_weights, include_last_offset);
1334 
1335   at::native::make_offset2bag_out(
1336       offset2bag,
1337       output,
1338       weight,
1339       indices,
1340       offsets,
1341       mode,
1342       per_sample_weights,
1343       padding_idx.value_or(-1));
1344 
1345   at::native::make_bag_size_out(
1346       bag_size, offsets, indices, mode, include_last_offset, false);
1347 
1348   if (p_max_indices) {
1349     at::native::make_max_indices_out(
1350         *p_max_indices,
1351         weight,
1352         indices,
1353         offsets,
1354         bag_size,
1355         mode,
1356         include_last_offset);
1357   }
1358 
1359   at::native::_embedding_bag_cpu_impl_out(
1360       output,
1361       offset2bag,
1362       bag_size,
1363       p_max_indices,
1364       weight,
1365       indices,
1366       offsets,
1367       mode,
1368       per_sample_weights,
1369       include_last_offset,
1370       padding_idx.value_or(-1),
1371       fbgemm_kernel_cache);
1372 }
1373 
1374 Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_,
1375                               const Tensor &offsets_,
1376                               const Tensor &offset2bag,
1377                               const Tensor &bag_size_,
1378                               const Tensor &max_indices_,
1379                               int64_t num_weights,
1380                               bool scale_grad_by_freq, int64_t mode,
1381                               bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
1382                               int64_t padding_idx) {
1383     return at::native::_embedding_bag_backward_symint(
1384         grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights_opt, padding_idx);
1385 }
1386 
1387 // Assumes all input tensors are contiguous.
1388 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
1389 Tensor _embedding_bag_backward_symint(const Tensor &grad, const Tensor &indices_,
1390                               const Tensor &offsets_,
1391                               const Tensor &offset2bag,
1392                               const Tensor &bag_size_,
1393                               const Tensor &max_indices_,
1394                               c10::SymInt num_weights,
1395                               bool scale_grad_by_freq, int64_t mode,
1396                               bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
1397                               int64_t padding_idx) {
1398   // See [Note: hacky wrapper removal for optional tensor]
1399   c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
1400   const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
1401 
1402   auto [indicesMaybeOwned, offsetsMaybeOwned] = promoteIndicesAndOffsets(indices_, offsets_);
1403   const auto& indices = *indicesMaybeOwned;
1404   const auto& offsets = *offsetsMaybeOwned;
1405   auto indices_arg = TensorArg(indices, "indices", 1);
1406   checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
1407   checkContiguous("embedding_bag", indices_arg);
1408   auto offsets_arg = TensorArg(offsets, "offsets", 1);
1409   checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt});
1410   checkSameType("embedding_bag", indices_arg, offsets_arg);
1411   checkContiguous("embedding_bag", offsets_arg);
1412 
1413   Tensor offset2bag_;
1414   if (indices.sym_numel() != 0 && offset2bag.sym_numel() == 0) {
1415     offset2bag_ = offsets.new_zeros(
1416       {indices.size(0) + 1}, offsets.options()); // offset2bag = [0 0 0 0 0]
1417 
1418     make_offset2bag(offsets, offset2bag_);
1419     // For Composite Compliance: if `offset2bag_` is a CCT (a Tensor
1420     // subclass), we can't call `resize_` on it. Instead we call `narrow`
1421     // to slice the tensor.
1422     if (isTensorSubclassLike(offset2bag_)) {
1423       offset2bag_ = offset2bag_.narrow(0, 0, indices.size(0));
1424     } else {
1425       offset2bag_.resize_({indices.size(0)});
1426     }
1427   } else {
1428     auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
1429     checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt});
1430     checkContiguous("embedding_bag", offset2bag_arg);
1431     offset2bag_ = offset2bag;
1432   }
1433 
1434   if (sparse) {
1435     return at::_embedding_bag_sparse_backward_symint(
1436         grad, indices, offsets, offset2bag_, bag_size_, std::move(num_weights),
1437         scale_grad_by_freq, mode, per_sample_weights, padding_idx);
1438   } else {
1439     return at::_embedding_bag_dense_backward_symint(
1440         grad, indices, offset2bag_, bag_size_, max_indices_, std::move(num_weights),
1441         scale_grad_by_freq, mode, per_sample_weights, padding_idx);
1442   }
1443 }
1444 
1445 static Tensor _embedding_bag_dense_backward_cpu_max(
1446     const Tensor& grad,
1447     const Tensor& bag_size,
1448     const Tensor& max_indices,
1449     int64_t num_weights) {
1450   AT_ASSERT(max_indices.defined());
1451   auto index_grad_weight =
1452       at::zeros({num_weights, grad.sizes()[1]}, grad.options());
1453   auto nonempty_max_indices = max_indices.index_select(0, bag_size.nonzero().view(-1));
1454   auto nonempty_grad = grad.index_select(0, bag_size.nonzero().view(-1));
1455 
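  // For every feature dimension d, route each bag's gradient back to the weight
  // row that produced the max, i.e.
  //   index_grad_weight[max_indices[bag][d]][d] += grad[bag][d]
  // (bags with bag_size == 0 were filtered out above).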
1456   for (const auto dim : c10::irange(grad.sizes()[1])) {
1457     index_grad_weight.select(1, dim).index_add_(
1458       0, nonempty_max_indices.select(1, dim), nonempty_grad.select(1, dim));
1459   }
1460   return index_grad_weight;
1461 }
1462 
1463 template<typename index_t>
1464 static std::vector<index_t> compute_counts(
1465     int64_t num_weights,
1466     const index_t* indices_data,
1467     int64_t indices_length) {
1468   std::vector<index_t> counts(num_weights, 0);
1469   for (const auto i : c10::irange(indices_length)) {
1470     counts[indices_data[i]]++;
1471   }
1472   return counts;
1473 }
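// e.g. with num_weights = 6 and indices = [0, 0, 0, 1, 3, 3, 4] (the same data
// as the worked example below), counts = [3, 1, 0, 2, 1, 0].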
1474 
1475 // counts_uniq stores the index of the NEXT unique element
1476 // of the (sorted) indices vector.
1477 //
1478 // For example:
1479 // indices: [0, 0, 0, 1, 3, 3, 4]
1480 // counts: [3, 1, 0, 2, 1, 0]
1481 // counts_uniq: [3, 4, 6, 7]
1482 //
1483 // The unique indices can be found at index 0, 3, 4, 6.
1484 template<typename index_t>
1485 static std::vector<index_t> compute_counts_uniq(
1486     int64_t num_weights,
1487     const index_t* indices_data,
1488     int64_t indices_length,
1489     const std::vector<index_t>& counts) {
1490   std::vector<index_t> counts_uniq;
1491   counts_uniq.reserve(num_weights);
1492   int64_t o = 0;
1493   for (int64_t i = 0; i < indices_length; i += counts[indices_data[i]]) {
1494     counts_uniq.push_back(counts[indices_data[i]]);
1495     if (o > 0) {
1496       counts_uniq[o] += counts_uniq[o - 1];
1497     }
1498     o++;
1499   }
1500   return counts_uniq;
1501 }
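// Together, compute_counts and compute_counts_uniq let the backward kernel walk
// the sorted indices one unique embedding row at a time (see the loop below).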
1502 
1503 template <typename scalar_t>
1504 void _embedding_bag_dense_backward_cpu_sum_mean(
1505     const Tensor& grad,
1506     const Tensor& indices_,
1507     const Tensor& offset2bag__,
1508     const Tensor& bag_size_,
1509     int64_t num_weights,
1510     bool scale_grad_by_freq,
1511     int64_t mode,
1512     const Tensor& per_sample_weights_,
1513     Tensor& index_grad_weight,
1514     int64_t padding_idx) {
1515 
1516   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1517   Tensor &offset2bag_ = const_cast<Tensor &>(offset2bag__);
1518 
1519   auto ind_sort_ = indices_.sort();
1520   auto indices = std::get<0>(ind_sort_);
1521   auto ind_sort = std::get<1>(ind_sort_);
1522   auto offset2bag = offset2bag_.index_select(0, ind_sort);
1523 
1524   std::optional<Tensor> per_sample_weights;
1525   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1526   const scalar_t* per_sample_weights_data;
1527   std::optional<int64_t> per_sample_weights_stride;
1528   if (per_sample_weights_.defined()) {
1529     per_sample_weights = per_sample_weights_.index_select(0, ind_sort);
1530     per_sample_weights_data = per_sample_weights->const_data_ptr<scalar_t>();
1531     per_sample_weights_stride = per_sample_weights->strides()[0];
1532   }
1533 
1534   int64_t numel = indices.numel();
1535 
1536   // Explicitly capture all required variables to work around a Windows build issue.
1537   // TODO: fix this when Windows can correctly capture variables in nested lambdas.
1538   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_dense_backward_cpu_sum_mean",
1539     [&indices, &offset2bag, &bag_size_, &num_weights, &numel, &per_sample_weights,
1540       &per_sample_weights_data, &per_sample_weights_stride, &mode, &scale_grad_by_freq,
1541       &grad, &index_grad_weight, &padding_idx] {
1542     auto* indices_data = indices.const_data_ptr<index_t>();
1543     auto* offset2bag_data = offset2bag.const_data_ptr<index_t>();
1544     auto* bag_size_data = bag_size_.const_data_ptr<index_t>();
1545 
1546     auto counts = compute_counts(num_weights, indices_data, numel);
1547     auto next_unique_index_idx =
1548         compute_counts_uniq(num_weights, indices_data, numel, counts);
1549 
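    // Each iteration i handles one run of equal (sorted) indices: every
    // occurrence j in that run adds scale * grad[offset2bag[j]] into
    // index_grad_weight[index], where scale folds in per_sample_weights
    // (sum mode only), the frequency scaling, and 1/bag_size for mean mode.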
1550     auto loop =
1551       [&next_unique_index_idx, &indices_data, &offset2bag_data, &bag_size_data, &per_sample_weights,
1552         &mode, &per_sample_weights_data, &per_sample_weights_stride, &scale_grad_by_freq,
1553         &counts, &grad, &index_grad_weight, &padding_idx
1554       ](index_t start, index_t end) {
1555       for (index_t i = start; i < end; i++) {
1556         index_t start = i == 0 ? 0 : next_unique_index_idx[i - 1];
1557         index_t index = indices_data[start];
1558 
1559         if (index != static_cast<index_t>(padding_idx)) {
1560           for (index_t j = start; j < next_unique_index_idx[i]; j++) {
1561             index_t source = offset2bag_data[j];
1562             double scale = 1.0;
1563             if (per_sample_weights) {
1564               AT_ASSERT(mode == EmbeddingBagMode::SUM);
1565               scale = per_sample_weights_data[*per_sample_weights_stride * j];
1566             }
1567             if (scale_grad_by_freq) {
1568               scale /= counts[index];
1569             }
1570             if (mode == EmbeddingBagMode::MEAN) {
1571               auto bag_size = bag_size_data[source];
1572               if (bag_size != 0) {
1573                 scale /= bag_size;
1574               }
1575             }
1576             int64_t ddim = grad.size(1);
1577             auto igwd = index_grad_weight.data_ptr<scalar_t>();
1578             auto gd = grad.const_data_ptr<scalar_t>();
1579             at::native::cpublas::axpy<scalar_t>(ddim, (scalar_t)scale, gd + ddim * source, 1,
1580                         igwd + ddim * index, 1);
1581           }
1582         }
1583       }
1584     };
1585 
1586     if (numel > 1000) {
1587       at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop);
1588     } else {
1589       loop(0, (int64_t)next_unique_index_idx.size());
1590     }
1591   });
1592 }
1593 
1594 Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indices_,
1595                                   const Tensor &offset2bag__,
1596                                   const Tensor &bag_size_,
1597                                   const Tensor& max_indices_, int64_t num_weights,
1598                                   bool scale_grad_by_freq, int64_t mode, const std::optional<Tensor>& per_sample_weights__opt,
1599                                   int64_t padding_idx) {
1600   // See [Note: hacky wrapper removal for optional tensor]
1601   c10::MaybeOwned<Tensor> per_sample_weights__maybe_owned = at::borrow_from_optional_tensor(per_sample_weights__opt);
1602   const Tensor& per_sample_weights_ = *per_sample_weights__maybe_owned;
1603 
1604   // indices_, offsets_ and offset2bag__ are assumed to have correct dtypes and
1605   // to be contiguous here due to the checks in _embedding_bag_backward above.
1606   // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
1607   // for more details.
1608   auto grad = grad_.contiguous();
1609   auto grad_arg = TensorArg(grad, "grad_", 1);
1610   checkScalarTypes(
1611       "embedding_bag", grad_arg, {kHalf, kBFloat16, kFloat, kDouble});
1612 
1613   if (mode == EmbeddingBagMode::MAX) {
1614     return _embedding_bag_dense_backward_cpu_max(
1615         grad_, bag_size_, max_indices_, num_weights);
1616   }
1617   AT_ASSERT(mode == EmbeddingBagMode::MEAN || mode == EmbeddingBagMode::SUM);
1618 
1619   auto index_grad_weight =
1620       at::zeros({num_weights, grad.sizes()[1]}, grad.options());
1621 
1622   AT_DISPATCH_FLOATING_TYPES_AND2(
1623       at::ScalarType::Half,
1624       at::ScalarType::BFloat16,
1625       grad.scalar_type(),
1626       "embedding_bag_backward",
1627       [&] {
1628         _embedding_bag_dense_backward_cpu_sum_mean<scalar_t>(
1629             grad,
1630             indices_,
1631             offset2bag__,
1632             bag_size_,
1633             num_weights,
1634             scale_grad_by_freq,
1635             mode,
1636             per_sample_weights_,
1637             index_grad_weight,
1638             padding_idx);
1639       });
1640   return index_grad_weight;
1641 }
1642 
1643 template<typename scalar_t>
1644 Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
1645     const Tensor& grad,
1646     const Tensor& weight,  // NB: embedding table, not per_sample_weights
1647     const Tensor& indices_,
1648     const Tensor& offsets_,
1649     const Tensor& offset2bag,
1650     int64_t mode,
1651     int64_t padding_idx) {
1652   TORCH_CHECK(
1653       mode == EmbeddingBagMode::SUM,
1654       "embedding_bag_backward: per_sample_weights only supported for mode='sum'");
1655 
1656   AT_ASSERT(grad.dim() == 2);
1657   auto embedding_features = grad.sizes()[1];
1658 
1659   auto [indicesMaybeOwned, offsetsMaybeOwned] = promoteIndicesAndOffsets(indices_, offsets_);
1660   const auto& indices = *indicesMaybeOwned;
1661   const auto& offsets = *offsetsMaybeOwned;
1662 
1663   AT_ASSERT(indices.dim() == 1);
1664   auto num_samples = indices.size(0);
1665 
1666   AT_ASSERT(weight.dim() == 2);
1667   AT_ASSERT(weight.sizes()[1] == embedding_features);
1668 
1669   auto output = at::zeros({num_samples}, grad.options());
1670 
1671   auto indices_arg = TensorArg(indices, "indices", 1);
1672   checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
1673   checkContiguous("embedding_bag", indices_arg);
1674 
1675   Tensor offset2bag_;
1676   if (indices.numel() != 0 && offset2bag.numel() == 0) {
1677     offset2bag_ = at::zeros(
1678        {indices.size(0) + 1}, offset2bag.options()); // offset2bag = [0 0 0 0 0]
1679 
1680     make_offset2bag(offsets, offset2bag_);
1681 
1682     at::native::resize_(offset2bag_, {indices.size(0)}, std::nullopt);
1683   } else {
1684     auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
1685     checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt});
1686     checkContiguous("embedding_bag", offset2bag_arg);
1687     offset2bag_ = offset2bag;
1688   }
1689 
1690   auto* grad_data = grad.const_data_ptr<scalar_t>();
1691   auto grad_stride0 = grad.strides()[0];
1692   auto grad_stride1 = grad.strides()[1];
1693 
1694   auto* weight_data = weight.const_data_ptr<scalar_t>();
1695   auto weight_stride0 = weight.strides()[0];
1696   auto weight_stride1 = weight.strides()[1];
1697 
1698   // Explicitly capture all required variables to work around a Windows build issue.
1699   // TODO: fix this when Windows can correctly capture variables in nested lambdas.
1700   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu_template",
1701     [&indices, &output, &offset2bag_, &num_samples, &embedding_features,
1702       &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, &weight_stride1,
1703       &padding_idx] () {
1704     auto* indices_data = indices.const_data_ptr<index_t>();
1705 
1706     // The following are contiguous
1707     auto* output_data = output.data_ptr<scalar_t>();
1708     auto* offset2bag_data = offset2bag_.const_data_ptr<index_t>();
1709 
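    // For mode='sum', the gradient w.r.t. per_sample_weights[i] is the dot
    // product of sample i's bag gradient with the embedding row it selected:
    //   output[i] = <grad[offset2bag[i], :], weight[indices[i], :]>
    // Positions whose index equals padding_idx keep a zero gradient.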
1710     // XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number.
1711     parallel_for(0, num_samples, 64,
1712       [&embedding_features, &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0,
1713         &weight_stride1, &offset2bag_data, &indices_data, &output_data, &padding_idx](index_t begin, index_t end) {
1714       for (index_t sample_idx = begin; sample_idx < end; sample_idx++) {
1715         auto bag_idx = offset2bag_data[sample_idx];
1716         auto embedding_idx = indices_data[sample_idx];
1717 
1718         if (embedding_idx != static_cast<index_t>(padding_idx)) {
1719           output_data[sample_idx] = dot_impl<scalar_t>(
1720               embedding_features,
1721               const_cast<scalar_t*>(grad_data + grad_stride0 * bag_idx), grad_stride1,
1722               const_cast<scalar_t*>(weight_data + weight_stride0 * embedding_idx), weight_stride1);
1723         }
1724       }
1725     });
1726   });
1727   return output;
1728 }
1729 
1730 Tensor _embedding_bag_per_sample_weights_backward_cpu(
1731     const Tensor& grad,
1732     const Tensor& weight,  // NB: embedding table, not per_sample_weights
1733     const Tensor& indices,
1734     const Tensor& offsets,
1735     const Tensor& offset2bag,
1736     int64_t mode,
1737     int64_t padding_idx) {
1738   return AT_DISPATCH_FLOATING_TYPES_AND2(
1739       at::ScalarType::Half,
1740       at::ScalarType::BFloat16,
1741       grad.scalar_type(),
1742       "_embedding_bag_per_sample_weights_backward_cpu",
1743       [&]() {
1744         return _embedding_bag_per_sample_weights_backward_cpu_template<
1745             scalar_t>(
1746             grad, weight, indices, offsets, offset2bag, mode, padding_idx);
1747       });
1748 }
1749 
1750 Tensor _embedding_bag_sparse_backward_symint(
1751     const Tensor &grad_, const Tensor &indices, const Tensor &offsets,
1752     const Tensor &offset2bag, const Tensor &bag_size_, SymInt num_weights,
1753     bool scale_grad_by_freq, int64_t mode, const std::optional<Tensor>& per_sample_weights_opt,
1754     int64_t padding_idx) {
1755   // See [Note: hacky wrapper removal for optional tensor]
1756   c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
1757   const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
1758 
1759   // indices, offsets and offset2bag are assumed to have correct dtypes and
1760   // to be contiguous here due to the checks in _embedding_bag_backward above.
1761   // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
1762   // for more details.
1763 
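  // Each sample i receives the gradient of its bag, grad[offset2bag[i]],
  // rescaled for 'mean' mode (and by per_sample_weights for weighted 'sum'),
  // and is then scattered into a sparse [num_weights, D] gradient by
  // embedding_backward.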
1764   // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
1765   Tensor grad = grad_;
1766   Tensor index_grad = grad_.index_select(0, offset2bag);
1767 
1768   index_grad = apply_bag_size_backward(mode, index_grad, offset2bag, bag_size_);
1769 
1770   if (per_sample_weights.defined()) {
1771     AT_ASSERT(mode == EmbeddingBagMode::SUM);
1772     index_grad.mul_(per_sample_weights.unsqueeze(1));
1773   }
1774   return native::embedding_backward_symint(index_grad, indices, std::move(num_weights), padding_idx,
1775                                     scale_grad_by_freq, true);
1776 }
1777 } // namespace at::native
1778