1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/TensorOperators.h>
5 #include <ATen/TensorSubclassLikeUtils.h>
6 #include <ATen/TensorUtils.h>
7 #include <ATen/cpu/vec/vec.h>
8 #include <ATen/native/EmbeddingBag.h>
9
10 #include <ATen/native/CPUBlas.h>
11 #include <ATen/native/NonSymbolicBC.h>
12
13 #include <c10/util/irange.h>
14 #include <c10/util/Half.h>
15
16 #ifdef USE_FBGEMM
17 #include <fbgemm/Fbgemm.h>
18 #include <fbgemm/FbgemmConvert.h>
19 #else
20 #include <caffe2/perfkernels/embedding_lookup_idx.h>
21 #endif
22
23 #include <cstring>
24 #include <tuple>
25 #include <utility>
26 #include <vector>
27
28 #ifndef AT_PER_OPERATOR_HEADERS
29 #include <ATen/Functions.h>
30 #include <ATen/NativeFunctions.h>
31 #else
32 #include <ATen/ops/_embedding_bag.h>
33 #include <ATen/ops/_embedding_bag_backward_native.h>
34 #include <ATen/ops/_embedding_bag_dense_backward.h>
35 #include <ATen/ops/_embedding_bag_dense_backward_native.h>
36 #include <ATen/ops/_embedding_bag_forward_only.h>
37 #include <ATen/ops/_embedding_bag_forward_only_native.h>
38 #include <ATen/ops/_embedding_bag_native.h>
39 #include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
40 #include <ATen/ops/_embedding_bag_sparse_backward.h>
41 #include <ATen/ops/_embedding_bag_sparse_backward_native.h>
42 #include <ATen/ops/embedding_backward_native.h>
43 #include <ATen/ops/embedding_bag_native.h>
44 #include <ATen/ops/empty.h>
45 #include <ATen/ops/max.h>
46 #include <ATen/ops/ones_like.h>
47 #include <ATen/ops/resize_native.h>
48 #include <ATen/ops/zero_native.h>
49 #include <ATen/ops/zeros.h>
50 #endif
51
52 namespace at::native {
53
54 template<typename scalar_t>
55 scalar_t dot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy);
56
57 static void make_offset2bag(const Tensor &offsets, Tensor& offset2bag) {
58 offset2bag.index_add_(
59 0, offsets, at::ones_like(offsets, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); // offset2bag = [1 0 1 0 1]
60 offset2bag[0] -= 1; // offset2bag = [0 0 1 0 1]
61 offset2bag = offset2bag.cumsum(0, offset2bag.scalar_type()); // offset2bag = [0 0 1 1 2]
62 }
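// Illustrative usage sketch (comment only, not compiled): the caller passes a
// zero-initialized offset2bag with one slot per index. Reusing the example
// traced above (5 indices, offsets = [0, 2, 4]):
//   Tensor offsets = at::tensor({0, 2, 4}, at::kLong);
//   Tensor offset2bag = at::zeros({5}, offsets.options());
//   make_offset2bag(offsets, offset2bag);
//   // index_add_ marks bag starts: [1 0 1 0 1]; after the -1 and the cumsum,
//   // offset2bag == [0 0 1 1 2], i.e. the bag id owning each index.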
63
64 namespace {
65
66 std::pair<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>> promoteIndicesAndOffsets(
67 const Tensor& indices,
68 const Tensor& offsets) {
69 const auto commonType =
70 promoteTypes(offsets.scalar_type(), indices.scalar_type());
71 return {
72 indices.scalar_type() == commonType ? c10::MaybeOwned<Tensor>::borrowed(indices)
73 : c10::MaybeOwned<Tensor>::owned(indices.toType(commonType)),
74 offsets.scalar_type() == commonType ? c10::MaybeOwned<Tensor>::borrowed(offsets)
75 : c10::MaybeOwned<Tensor>::owned(offsets.toType(commonType))};
76 }
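// Example (illustrative; the input variable names are hypothetical): mixed
// index dtypes are promoted to their common type, borrowing when no
// conversion is needed:
//   auto [ind, off] = promoteIndicesAndOffsets(indices_int32, offsets_int64);
//   // *ind is an owned kLong copy of indices_int32; *off borrows offsets_int64.
//   // If both inputs already share a dtype, both are borrowed and nothing is
//   // copied.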
77
78 // Determines if we can use a fast implementation for index_select_add, which
79 // is only applicable if special conditions are met
80 template<typename index_t>
81 bool is_fast_path_index_select(const Tensor& src, Tensor& output, index_t padding_idx) {
82 return (src.scalar_type() == kFloat || src.scalar_type() == kHalf ||
83 src.scalar_type() == kBFloat16) &&
84 src.strides()[1] == 1 && output.strides()[1] == 1 &&
85 padding_idx < static_cast<index_t>(0);
86 }
87
88 // Determines if we can use a fast implementation for index_select_scale_add,
89 // which is only applicable if special conditions are met
90 template<typename index_t>
91 bool is_fast_path_index_select_scale(const Tensor& src, const Tensor& scale, Tensor& output, index_t padding_idx) {
92 return (src.scalar_type() == kFloat || src.scalar_type() == kHalf ||
93 src.scalar_type() == kBFloat16) &&
94 src.strides()[1] == 1 && output.strides()[1] == 1 &&
95 scale.strides()[0] == 1 && padding_idx < static_cast<index_t>(0);
96 }
97
98 template<typename index_t>
99 bool is_fast_path(const Tensor& src, const std::optional<Tensor>& scale, Tensor& output, index_t padding_idx) {
100 return (scale.has_value() && scale.value().defined()) ?
101 is_fast_path_index_select_scale(src, scale.value(), output, padding_idx) :
102 is_fast_path_index_select(src, output, padding_idx);
103 }
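// Example (illustrative): a kFloat/kHalf/kBFloat16 weight whose rows have a
// stride-1 feature dimension, an output with the same property,
// per_sample_weights (if given) contiguous along dim 0, and padding_idx < 0
// together take the fast FBGEMM/caffe2 path; a kDouble weight, a transposed
// weight, or padding_idx >= 0 falls back to the per-index loops below.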
104
105 // This function combines index_select (using select_indices as the index) and
106 // index_add (using add_indices as the index), without creating an intermediary
107 // tensor to hold the selected embeddings
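// Conceptually equivalent composition (illustrative sketch, not compiled;
// padding_idx handling and bag_size bookkeeping aside):
//   Tensor selected = src.index_select(0, select_indices);
//   output.index_add_(0, add_indices, selected);
// The fused loop below avoids materializing `selected`.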
108 template <typename data_t, typename index_t>
109 static typename std::enable_if<std::is_same<data_t, double>::value, void>::type
110 index_select_add(
111 const Tensor& select_indices,
112 const Tensor& add_indices,
113 const Tensor& src,
114 Tensor& output,
115 const Tensor& /*offsets*/,
116 bool /*include_last_offset*/,
117 Tensor& bag_size,
118 index_t padding_idx,
119 _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
120 TORCH_CHECK(select_indices.numel() == add_indices.numel());
121 auto* add_indices_data = add_indices.const_data_ptr<index_t>();
122 auto* select_indices_data = select_indices.const_data_ptr<index_t>();
123 auto* src_data = src.const_data_ptr<data_t>();
124 auto* output_data = output.data_ptr<data_t>();
125 index_t* bag_size_data = nullptr;
126 if (bag_size.defined()) {
127 bag_size_data = bag_size.data_ptr<index_t>();
128 }
129 auto numel = add_indices.numel();
130 int64_t ddim = src.size(1);
131 auto vocab_size = src.size(0);
132 auto src_stride0 = src.strides()[0];
133 auto src_stride1 = src.strides()[1];
134 auto output_stride0 = output.strides()[0];
135 auto output_stride1 = output.strides()[1];
136
137 for (const auto i : c10::irange(numel)) {
138 // We can skip indices equal to padding_idx so they are not included in
139 // the reduction
140 auto idx = select_indices_data[i];
141 TORCH_CHECK(
142 idx >= 0 && idx < vocab_size,
143 "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
144 idx);
145 if (idx != padding_idx) {
146 at::native::cpublas::axpy<data_t>(ddim, 1,
147 src_data + src_stride0 * idx, src_stride1,
148 output_data + output_stride0 * add_indices_data[i], output_stride1);
149 } else if (bag_size_data) {
150 // Decrement bag_size to reflect that the index is padded
151 bag_size_data[add_indices_data[i]]--;
152 }
153 }
154 }
155
156 namespace {
157 template <typename index_t>
158 void fbgemm_spmdm_report_error_(
159 int64_t output_size,
160 int index_size,
161 int64_t N,
162 const index_t* offsets,
163 const index_t* indices) {
164 for (const auto m : c10::irange(output_size)) {
165 for (index_t i = offsets[m]; i < offsets[m + 1]; ++i) {
166 TORCH_CHECK(i < index_size);
167 index_t idx = indices[i];
168 TORCH_CHECK(
169 0 <= idx && idx < N,
170 "Index ",
171 i,
172 " of input takes value ",
173 idx,
174 " which is not in the valid range [0, ",
175 N,
176 ")");
177 }
178 }
179 TORCH_CHECK(
180 offsets[output_size] == index_size,
181 "Your input appears to be incorrect: the last offset value should be "
182 "the size of the indices tensor, but it seems not to be the case.");
183 }
184 } // namespace
185
186 template <typename data_t, typename index_t>
187 typename std::enable_if<
188 std::is_same<data_t, at::Half>::value ||
189 std::is_same<data_t, at::BFloat16>::value,
190 void>::type
191 index_select_add(
192 const Tensor& select_indices,
193 const Tensor& add_indices,
194 const Tensor& src,
195 Tensor& output,
196 const Tensor& offsets,
197 bool include_last_offset,
198 Tensor& bag_size,
199 index_t padding_idx,
200 _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
201 int64_t ddim = src.size(1);
202 auto* select_indices_data = select_indices.const_data_ptr<index_t>();
203 auto* output_data = output.data_ptr<data_t>();
204
205 if (is_fast_path_index_select(src, output, padding_idx)) {
206 auto src_contig = src.contiguous();
207 auto* src_data = src_contig.const_data_ptr<data_t>();
208 int64_t output_size = offsets.numel() - 1;
209 auto* offsets_data = offsets.const_data_ptr<index_t>();
210 std::vector<index_t> offsets_include_last;
211
212 if (include_last_offset) {
213 output_size = offsets.numel() - 1;
214 } else {
215 output_size = offsets.numel();
216 offsets_include_last.resize(offsets.numel() + 1);
217 if (offsets.numel() > 0) {
218 std::memcpy(
219 offsets_include_last.data(),
220 offsets.const_data_ptr<index_t>(),
221 sizeof(index_t) * offsets.numel());
222 }
223 offsets_include_last[offsets.numel()] = select_indices.numel();
224 offsets_data = offsets_include_last.data();
225 }
226 #if defined(USE_FBGEMM)
227 constexpr bool isbf16 = !std::is_same_v<data_t, at::Half>;
228 auto kernel_16bit_index_t = fbgemm_kernel_cache
229 ? fbgemm_kernel_cache
230 ->getCallback</* has_weight */ false, index_t, uint16_t>(ddim)
231 : fbgemm::GenerateEmbeddingSpMDM<uint16_t, index_t, index_t, uint16_t>(
232 /* block_size */ ddim,
233 /* has_weight */ false,
234 /* normalize_by_lengths */ false,
235 /* prefetch */ 16,
236 /* is_weight_positional */ false,
237 /* use_offsets */ true,
238 /* is_bf16_out */ isbf16,
239 /* is_bf16_in */ isbf16);
240 at::parallel_for(
241 0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
242 bool success = kernel_16bit_index_t(
243 /* output_size */ end_idx - start_idx,
244 /* index_size */ offsets_data[end_idx] - offsets_data[start_idx],
245 /* data_size */ src.size(0),
246 /* input */ reinterpret_cast<const uint16_t*>(src_data),
247 /* indices */ select_indices_data + offsets_data[start_idx],
248 /* offsets_or_lengths */ offsets_data + start_idx,
249 /* weights */ nullptr,
250 /* output */
251 reinterpret_cast<uint16_t*>(output_data + start_idx * ddim));
252 if (!success) {
253 fbgemm_spmdm_report_error_(
254 end_idx - start_idx,
255 offsets_data[end_idx] - offsets_data[start_idx],
256 src.size(0),
257 offsets_data + start_idx,
258 select_indices_data + offsets_data[start_idx]);
259 }
260 });
261 #else
262 // Initialize the intermediate output buffer to be 0.
263 Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
264 auto* output_data_fp32 = output_fp32.data_ptr<float>();
265 using bVec = vec::Vectorized<BFloat16>;
266 using fVec = vec::Vectorized<float>;
267 at::parallel_for(
268 0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
269 caffe2::EmbeddingLookupIdx(
270 /*block_size=*/ddim,
271 /*output_size=*/end_idx - start_idx,
272 /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
273 /*data_size=*/src.size(0),
274 /*input=*/src_data,
275 /*indices=*/select_indices_data + offsets_data[start_idx],
276 /*offsets=*/offsets_data + start_idx,
277 /*weights=*/nullptr,
278 /*scale_bias=*/nullptr,
279 /*normalize_by_lengths=*/false,
280 /*out=*/output_data_fp32 + start_idx * ddim);
281 for (int64_t i = start_idx; i < end_idx; i++) {
282 // Convert FP32 intermediate buffer result back to 16 bit for
283 // output dtype
284 if constexpr (std::is_same<data_t, at::Half>::value) {
285 // FP16
286 for (const auto d : c10::irange(ddim)) {
287 (output_data + i * ddim)[d] =
288 static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
289 }
290 } else {
291 // BF16
292 int64_t d = 0;
293 for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) {
294 fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d);
295 fVec temp_fp32_1 =
296 fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size());
297 convert_float_bfloat16(temp_fp32_0, temp_fp32_1)
298 .store(output_data + i * ddim + d);
299 }
300 for (; d < ddim; d++) {
301 (output_data + i * ddim)[d] =
302 static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
303 }
304 }
305 }
306 });
307 #endif
308 } else {
309 TORCH_CHECK(select_indices.numel() == add_indices.numel());
310 auto* src_data = src.const_data_ptr<data_t>();
311 auto* add_indices_data = add_indices.const_data_ptr<index_t>();
312 index_t* bag_size_data = nullptr;
313 if (bag_size.defined()) {
314 bag_size_data = bag_size.data_ptr<index_t>();
315 }
316 auto vocab_size = src.size(0);
317 auto src_stride0 = src.strides()[0];
318 auto src_stride1 = src.strides()[1];
319 auto output_stride0 = output.strides()[0];
320 auto output_stride1 = output.strides()[1];
321 auto numel = add_indices.numel();
322
323 Tensor src_fp32 = at::empty({ddim}, src.options().dtype(at::kFloat));
324 auto* src_data_fp32 = src_fp32.mutable_data_ptr<float>();
325
326 // Initialize the intermediate output buffer to be 0.
327 Tensor output_fp32 =
328 at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
329 auto* output_data_fp32 = output_fp32.data_ptr<float>();
330
331 for (const auto i : c10::irange(numel)) {
332 // We can skip indices equal to padding_idx so they are not included in
333 // the reduction
334 auto idx = select_indices_data[i];
335 TORCH_CHECK(
336 idx >= 0 && idx < vocab_size,
337 "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
338 idx);
339 if (idx != padding_idx) {
340 // Copy src_data + src_stride0 * idx to src_data_fp32
341 for (const auto d : c10::irange(ddim)) {
342 src_data_fp32[d] = static_cast<float>(
343 (src_data + src_stride0 * idx)[d * src_stride1]);
344 }
345 at::native::cpublas::axpy<float>(
346 ddim,
347 1,
348 src_data_fp32,
349 1,
350 output_data_fp32 + ddim * add_indices_data[i],
351 1);
352
353 } else if (bag_size_data) {
354 // Decrement bag_size to reflect that the index is padded
355 bag_size_data[add_indices_data[i]]--;
356 }
357 }
358 for (const auto i : c10::irange(output.size(0))) {
359 // Convert FP32 intermediate buffer result back to 16 bit for output
360 // dtype
361 for (const auto d : c10::irange(ddim)) {
362 (output_data + output_stride0 * i)[d * output_stride1] =
363 static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
364 }
365 }
366 }
367 }
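// Note on the reduced-precision path above: sums are accumulated in fp32 and
// rounded to fp16/bf16 once per output element, conceptually (schematic
// sketch; `bag`, `d`, and `weight_row` are placeholders):
//   float acc = 0.f;
//   for (auto idx : bag) acc += static_cast<float>(weight_row(idx)[d]);
//   output[bag][d] = static_cast<data_t>(acc);  // single rounding at the end
// which avoids the error growth of accumulating directly in 16-bit.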
368 template<typename data_t, typename index_t>
369 typename std::enable_if<std::is_same<data_t, float>::value, void>::type
370 index_select_add(const Tensor &select_indices,
371 const Tensor &add_indices,
372 const Tensor &src,
373 Tensor &output,
374 const Tensor& offsets,
375 bool include_last_offset,
376 Tensor &bag_size,
377 index_t padding_idx,
378 _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
379 int64_t ddim = src.size(1);
380 auto* select_indices_data = select_indices.const_data_ptr<index_t>();
381 auto* output_data = output.data_ptr<float>();
382
383 if (is_fast_path_index_select(src, output, padding_idx)) {
384 auto src_contig = src.contiguous();
385 auto* src_data = src_contig.const_data_ptr<float>();
386 int64_t output_size = offsets.numel() - 1;
387 auto* offsets_data = offsets.const_data_ptr<index_t>();
388 std::vector<index_t> offsets_include_last;
389
390 if (include_last_offset) {
391 output_size = offsets.numel() - 1;
392 } else {
393 output_size = offsets.numel();
394 offsets_include_last.resize(offsets.numel() + 1);
395 if (offsets.numel() > 0) {
396 std::memcpy(
397 offsets_include_last.data(),
398 offsets.const_data_ptr<index_t>(),
399 sizeof(index_t) * offsets.numel());
400 }
401 offsets_include_last[offsets.numel()] = select_indices.numel();
402 offsets_data = offsets_include_last.data();
403 }
404
405 #ifdef USE_FBGEMM
406 auto kernel_fp32_index_t =
407 fbgemm_kernel_cache ?
408 fbgemm_kernel_cache->getCallback</* has_weight */ false, index_t, float>(ddim) :
409 fbgemm::GenerateEmbeddingSpMDM<float, index_t, index_t>(
410 /* block_size */ddim,
411 /* has_weight */false,
412 /* normalize_by_lengths */false,
413 /* prefetch */16,
414 /* is_weight_positional */false,
415 /* use_offsets */true
416 );
417 #endif
418 at::parallel_for(
419 0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
420 #ifdef USE_FBGEMM
421 bool success = kernel_fp32_index_t(
422 /* output_size */end_idx - start_idx,
423 /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
424 /* data_size */src.size(0),
425 /* input */src_data,
426 /* indices */select_indices_data + offsets_data[start_idx],
427 /* offsets_or_lengths */offsets_data + start_idx,
428 /* weights */nullptr,
429 /* output */output_data + start_idx * ddim);
430 if (!success) {
431 fbgemm_spmdm_report_error_(
432 end_idx - start_idx,
433 offsets_data[end_idx] - offsets_data[start_idx],
434 src.size(0),
435 offsets_data + start_idx,
436 select_indices_data + offsets_data[start_idx]);
437 }
438 #else
439 caffe2::EmbeddingLookupIdx(
440 /*block_size=*/ddim,
441 /*output_size=*/end_idx - start_idx,
442 /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
443 /*data_size=*/src.size(0),
444 /*input=*/src_data,
445 /*indices=*/select_indices_data + offsets_data[start_idx],
446 /*offsets=*/offsets_data + start_idx,
447 /*weights=*/nullptr,
448 /*scale_bias=*/nullptr,
449 /*normalize_by_lengths=*/false,
450 /*out=*/output_data + start_idx * ddim);
451 #endif
452 });
453 } else {
454 AT_ASSERT(select_indices.numel() == add_indices.numel());
455 auto* src_data = src.const_data_ptr<float>();
456 auto* add_indices_data = add_indices.const_data_ptr<index_t>();
457 index_t* bag_size_data = nullptr;
458 if (bag_size.defined()) {
459 bag_size_data = bag_size.data_ptr<index_t>();
460 }
461 auto vocab_size = src.size(0);
462 auto src_stride0 = src.strides()[0];
463 auto src_stride1 = src.strides()[1];
464 auto output_stride0 = output.strides()[0];
465 auto output_stride1 = output.strides()[1];
466 auto numel = add_indices.numel();
467 for (const auto i : c10::irange(numel)) {
468 // We can skip indices equal to padding_idx so they are not included in
469 // the reduction
470 auto idx = select_indices_data[i];
471 TORCH_CHECK(
472 idx >= 0 && idx < vocab_size,
473 "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
474 idx);
475 if (idx != padding_idx) {
476 at::native::cpublas::axpy<float>(
477 ddim,
478 1,
479 src_data + src_stride0 * idx,
480 src_stride1,
481 output_data + output_stride0 * add_indices_data[i],
482 output_stride1);
483 } else if (bag_size_data) {
484 // Decrement bag_size to reflect that the index is padded
485 bag_size_data[add_indices_data[i]]--;
486 }
487 }
488 }
489 }
490
491 // This function fuses the following three fns:
492 // index_select (using select_indices as the index)
493 // mul (scaling by per_sample_weights)
494 // index_add (using add_indices as the index)
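// Equivalent composition (illustrative sketch, not compiled; padding_idx and
// bag_size bookkeeping aside):
//   Tensor scaled = src.index_select(0, select_indices) * scale.unsqueeze(1);
//   output.index_add_(0, add_indices, scaled);
// The fused loops below never materialize the scaled embeddings.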
495 template <typename data_t, typename index_t>
496 static typename std::enable_if<std::is_same<data_t, double>::value, void>::type
497 index_select_scale_add(
498 const Tensor& select_indices,
499 const Tensor& add_indices,
500 const Tensor& scale,
501 const Tensor& src,
502 Tensor& output,
503 const Tensor& /*offsets*/,
504 bool /*include_last_offset*/,
505 Tensor& bag_size,
506 index_t padding_idx,
507 _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
508 AT_ASSERT(select_indices.numel() == add_indices.numel());
509 auto* add_indices_data = add_indices.const_data_ptr<index_t>();
510 auto* select_indices_data = select_indices.const_data_ptr<index_t>();
511 auto* src_data = src.const_data_ptr<data_t>();
512 auto* output_data = output.data_ptr<data_t>();
513 index_t* bag_size_data = nullptr;
514 if (bag_size.defined()) {
515 bag_size_data = bag_size.data_ptr<index_t>();
516 }
517 auto numel = add_indices.numel();
518 int64_t ddim = src.size(1);
519 auto vocab_size = src.size(0);
520 auto src_stride0 = src.strides()[0];
521 auto src_stride1 = src.strides()[1];
522 auto output_stride0 = output.strides()[0];
523 auto output_stride1 = output.strides()[1];
524
525 auto* scale_data = scale.const_data_ptr<data_t>();
526 auto scale_stride = scale.strides()[0];
527
528 for (const auto i : c10::irange(numel)) {
529 // We can skip indices equal to padding_idx so they are not included in
530 // the reduction
531 auto idx = select_indices_data[i];
532 TORCH_CHECK(
533 idx >= 0 && idx < vocab_size,
534 "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
535 idx);
536 if (idx != padding_idx) {
537 auto* src_base = src_data + src_stride0 * idx;
538 auto* output_base = output_data + output_stride0 * add_indices_data[i];
539 auto scale = scale_data[i * scale_stride];
540 for (const auto j : c10::irange(ddim)) {
541 output_base[j * output_stride1] += src_base[j * src_stride1] * scale;
542 }
543 } else if (bag_size_data) {
544 // Decrement bag_size to reflect that the index is padded
545 bag_size_data[add_indices_data[i]]--;
546 }
547 }
548 }
549
550 template <typename data_t, typename index_t>
551 typename std::enable_if<
552 std::is_same<data_t, at::Half>::value ||
553 std::is_same<data_t, at::BFloat16>::value,
554 void>::type
555 index_select_scale_add(
556 const Tensor& select_indices,
557 const Tensor& add_indices,
558 const Tensor& scale,
559 const Tensor& src,
560 Tensor& output,
561 const Tensor& offsets,
562 bool include_last_offset,
563 Tensor& bag_size,
564 index_t padding_idx,
565 _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
566 int64_t ddim = src.size(1);
567 auto* scale_data = scale.const_data_ptr<data_t>();
568 auto* select_indices_data = select_indices.const_data_ptr<index_t>();
569 auto* output_data = output.data_ptr<data_t>();
570
571 if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) {
572 auto src_contig = src.contiguous();
573 auto* src_data = src_contig.const_data_ptr<data_t>();
574 int64_t output_size = offsets.numel() - 1;
575 auto* offsets_data = offsets.const_data_ptr<index_t>();
576 std::vector<index_t> offsets_include_last;
577
578 if (include_last_offset) {
579 output_size = offsets.numel() - 1;
580 } else {
581 output_size = offsets.numel();
582 offsets_include_last.resize(offsets.numel() + 1);
583 std::memcpy(
584 offsets_include_last.data(),
585 offsets.const_data_ptr<index_t>(),
586 sizeof(index_t) * offsets.numel());
587 offsets_include_last[offsets.numel()] = select_indices.numel();
588 offsets_data = offsets_include_last.data();
589 }
590
591 Tensor scale_fp32 = at::empty(scale.sizes(), scale.options().dtype(at::kFloat));
592 auto* scale_data_fp32 = scale_fp32.mutable_data_ptr<float>();
593
594 #if defined(USE_FBGEMM)
595 constexpr bool isbf16 = !std::is_same_v<data_t, at::Half>;
596 if constexpr (isbf16) {
597 fbgemm::Bfloat16ToFloat_simd(
598 reinterpret_cast<const fbgemm::bfloat16*>(scale_data),
599 scale_data_fp32,
600 scale_fp32.numel());
601 } else {
602 fbgemm::Float16ToFloat_simd(
603 reinterpret_cast<const fbgemm::float16*>(scale_data),
604 scale_data_fp32,
605 scale_fp32.numel());
606 }
607 auto kernel_16bit_index_t = fbgemm_kernel_cache
608 ? fbgemm_kernel_cache
609 ->getCallback</* has_weight */ true, index_t, uint16_t>(ddim)
610 : fbgemm::GenerateEmbeddingSpMDM<uint16_t, index_t, index_t, uint16_t>(
611 /* block_size */ ddim,
612 /* has_weight */ true,
613 /* normalize_by_lengths */ false,
614 /* prefetch */ 16,
615 /* is_weight_positional */ false,
616 /* use_offsets */ true,
617 /* is_bf16_out */ isbf16,
618 /* is_bf16_in */ isbf16);
619 at::parallel_for(
620 0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
621 bool success = kernel_16bit_index_t(
622 /* output_size */ end_idx - start_idx,
623 /* index_size */ offsets_data[end_idx] - offsets_data[start_idx],
624 /* data_size */ src.size(0),
625 /* input */ reinterpret_cast<const uint16_t*>(src_data),
626 /* indices */ select_indices_data + offsets_data[start_idx],
627 /* offsets_or_lengths */ offsets_data + start_idx,
628 /* weights */ scale_data_fp32 + offsets_data[start_idx],
629 /* output */
630 reinterpret_cast<uint16_t*>(output_data + start_idx * ddim));
631 if (!success) {
632 fbgemm_spmdm_report_error_(
633 end_idx - start_idx,
634 offsets_data[end_idx] - offsets_data[start_idx],
635 src.size(0),
636 offsets_data + start_idx,
637 select_indices_data + offsets_data[start_idx]);
638 }
639 });
640 #else
641 // Initialize the intermediate output buffer to be 0.
642 Tensor output_fp32 =
643 at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
644 auto* output_data_fp32 = output_fp32.data_ptr<float>();
645 for (const auto i : c10::irange(scale.numel())) {
646 scale_data_fp32[i] = static_cast<float>(scale_data[i]);
647 }
648 using bVec = vec::Vectorized<BFloat16>;
649 using fVec = vec::Vectorized<float>;
650 at::parallel_for(
651 0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
652 caffe2::EmbeddingLookupIdx(
653 /*block_size=*/ddim,
654 /*output_size=*/end_idx - start_idx,
655 /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
656 /*data_size=*/src.size(0),
657 /*input=*/src_data,
658 /*indices=*/select_indices_data + offsets_data[start_idx],
659 /*offsets=*/offsets_data + start_idx,
660 /*weights=*/scale_data_fp32 + offsets_data[start_idx],
661 /*scale_bias=*/nullptr,
662 /*normalize_by_lengths=*/false,
663 /*out=*/output_data_fp32 + start_idx * ddim);
664 for (int64_t i = start_idx; i < end_idx; i++) {
665 // Convert FP32 intermediate buffer result back to 16 bit for
666 // output dtype
667 if constexpr (std::is_same<data_t, at::Half>::value) {
668 // FP16
669 for (const auto d : c10::irange(ddim)) {
670 (output_data + i * ddim)[d] =
671 static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
672 }
673 } else {
674 // BF16
675 int64_t d = 0;
676 for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) {
677 fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d);
678 fVec temp_fp32_1 =
679 fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size());
680 convert_float_bfloat16(temp_fp32_0, temp_fp32_1)
681 .store(output_data + i * ddim + d);
682 }
683 for (; d < ddim; d++) {
684 (output_data + i * ddim)[d] =
685 static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
686 }
687 }
688 }
689 });
690 #endif
691 } else {
692 AT_ASSERT(select_indices.numel() == add_indices.numel());
693 auto* src_data = src.const_data_ptr<data_t>();
694 auto* add_indices_data = add_indices.const_data_ptr<index_t>();
695 index_t* bag_size_data = nullptr;
696 if (bag_size.defined()) {
697 bag_size_data = bag_size.data_ptr<index_t>();
698 }
699 auto vocab_size = src.size(0);
700 auto src_stride0 = src.strides()[0];
701 auto src_stride1 = src.strides()[1];
702 auto output_stride0 = output.strides()[0];
703 auto output_stride1 = output.strides()[1];
704 auto scale_stride = scale.strides()[0];
705 auto numel = add_indices.numel();
706
707 // Initialize the intermediate output buffer to be 0.
708 Tensor output_fp32 =
709 at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
710 auto* output_data_fp32 = output_fp32.data_ptr<float>();
711
712 for (const auto i : c10::irange(numel)) {
713 // We can skip indices equal to padding_idx so they are not included in
714 // the reduction
715 auto idx = select_indices_data[i];
716 TORCH_CHECK(
717 idx >= 0 && idx < vocab_size,
718 "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
719 idx);
720 if (idx != padding_idx) {
721 auto* src_base = src_data + src_stride0 * idx;
722 auto* output_base_fp32 = output_data_fp32 + ddim * add_indices_data[i];
723 auto scale = scale_data[i * scale_stride];
724 for (const auto j : c10::irange(ddim)) {
725 output_base_fp32[j] += static_cast<float>(src_base[j * src_stride1]) *
726 static_cast<float>(scale);
727 }
728 } else if (bag_size_data) {
729 // Decrement bag_size to reflect that the index is padded
730 bag_size_data[add_indices_data[i]]--;
731 }
732 }
733 for (const auto i : c10::irange(output.size(0))) {
734 // Convert FP32 intermediate buffer result back to 16 bit for output
735 // dtype
736 for (const auto d : c10::irange(ddim)) {
737 (output_data + output_stride0 * i)[d * output_stride1] =
738 static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
739 }
740 }
741 }
742 }
743 template<typename data_t, typename index_t>
744 typename std::enable_if<std::is_same<data_t, float>::value, void>::type
745 index_select_scale_add(const Tensor &select_indices,
746 const Tensor &add_indices,
747 const Tensor &scale,
748 const Tensor &src,
749 Tensor &output,
750 const Tensor& offsets,
751 bool include_last_offset,
752 Tensor &bag_size,
753 index_t padding_idx,
754 _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
755 int64_t ddim = src.size(1);
756 auto* scale_data = scale.const_data_ptr<float>();
757 auto* select_indices_data = select_indices.const_data_ptr<index_t>();
758 auto* output_data = output.data_ptr<float>();
759
760 if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) {
761 auto src_contig = src.contiguous();
762 auto* src_data = src_contig.const_data_ptr<float>();
763 int64_t output_size = offsets.numel() - 1;
764 auto* offsets_data = offsets.const_data_ptr<index_t>();
765 std::vector<index_t> offsets_include_last;
766
767 if (include_last_offset) {
768 output_size = offsets.numel() - 1;
769 } else {
770 output_size = offsets.numel();
771 offsets_include_last.resize(offsets.numel() + 1);
772 std::memcpy(
773 offsets_include_last.data(),
774 offsets.const_data_ptr<index_t>(),
775 sizeof(index_t) * offsets.numel());
776 offsets_include_last[offsets.numel()] = select_indices.numel();
777 offsets_data = offsets_include_last.data();
778 }
779
780 #ifdef USE_FBGEMM
781 auto kernel_fp32_index_t =
782 fbgemm_kernel_cache ?
783 fbgemm_kernel_cache->getCallback</* has_weight */ true, index_t, float>(ddim) :
784 fbgemm::GenerateEmbeddingSpMDM<float, index_t, index_t>(
785 /* block_size */ddim,
786 /* has_weight */true,
787 /* normalize_by_lengths */false,
788 /* prefetch */16,
789 /* is_weight_positional */false,
790 /* use_offsets */true
791 );
792 #endif
793 at::parallel_for(
794 0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
795 #ifdef USE_FBGEMM
796 bool success = kernel_fp32_index_t(
797 /* output_size */end_idx - start_idx,
798 /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
799 /* data_size */src.size(0),
800 /* input */src_data,
801 /* indices */select_indices_data + offsets_data[start_idx],
802 /* offsets_or_lengths */offsets_data + start_idx,
803 /* weights */scale_data + offsets_data[start_idx],
804 /* output */output_data + start_idx * ddim);
805 if (!success) {
806 fbgemm_spmdm_report_error_(
807 end_idx - start_idx,
808 offsets_data[end_idx] - offsets_data[start_idx],
809 src.size(0),
810 offsets_data + start_idx,
811 select_indices_data + offsets_data[start_idx]);
812 }
813 #else
814 caffe2::EmbeddingLookupIdx(
815 /*block_size=*/ddim,
816 /*output_size=*/end_idx - start_idx,
817 /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
818 /*data_size=*/src.size(0),
819 /*input=*/src_data,
820 /*indices=*/select_indices_data + offsets_data[start_idx],
821 /*offsets=*/offsets_data + start_idx,
822 /*weights=*/scale_data + offsets_data[start_idx],
823 /*scale_bias=*/nullptr,
824 /*normalize_by_lengths=*/false,
825 /*out=*/output_data + start_idx * ddim);
826 #endif
827 });
828 } else {
829 AT_ASSERT(select_indices.numel() == add_indices.numel());
830 auto* src_data = src.const_data_ptr<float>();
831 auto* add_indices_data = add_indices.const_data_ptr<index_t>();
832 index_t* bag_size_data = nullptr;
833 if (bag_size.defined()) {
834 bag_size_data = bag_size.data_ptr<index_t>();
835 }
836 auto vocab_size = src.size(0);
837 auto src_stride0 = src.strides()[0];
838 auto src_stride1 = src.strides()[1];
839 auto output_stride0 = output.strides()[0];
840 auto output_stride1 = output.strides()[1];
841 auto scale_stride = scale.strides()[0];
842 auto numel = add_indices.numel();
843
844
845 for (const auto i : c10::irange(numel)) {
846 // We can skip indices equal to padding_idx so they are not included in
847 // the reduction
848 auto idx = select_indices_data[i];
849 TORCH_CHECK(
850 idx >= 0 && idx < vocab_size,
851 "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
852 idx);
853 if (idx != padding_idx) {
854 auto* src_base = src_data + src_stride0 * idx;
855 auto* output_base = output_data + output_stride0 * add_indices_data[i];
856 auto scale = scale_data[i * scale_stride];
857 for (const auto j : c10::irange(ddim)) {
858 output_base[j * output_stride1] += src_base[j * src_stride1] * scale;
859 }
860 } else if (bag_size_data) {
861 // Decrement bag_size to reflect that the index is padded
862 bag_size_data[add_indices_data[i]]--;
863 }
864 }
865 }
866 }
867
868 } // namespace
869
870 void check_arguments(
871 const Tensor& weight,
872 const Tensor& indices,
873 const Tensor& offsets,
874 const int64_t mode,
875 const std::optional<Tensor>& per_sample_weights,
876 bool include_last_offset) {
877 auto indices_arg = TensorArg(indices, "indices", 1);
878 checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
879 auto offsets_arg = TensorArg(offsets, "offsets", 1);
880 checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt});
881 checkSameType("embedding_bag", indices_arg, offsets_arg);
882 auto weight_arg = TensorArg(weight, "weight", 1);
883 checkScalarTypes(
884 "embedding_bag", weight_arg, {kHalf, kBFloat16, kFloat, kDouble});
885
886 AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_embedding_bag_cpu_impl", [&]() {
887 if (offsets.size(0) > 0) {
888 index_t offset_0 = offsets.const_data_ptr<index_t>()[0];
889 index_t offset_n = offsets.const_data_ptr<index_t>()[offsets.size(0)-1];
890 TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence "
891 "in the mini-batch has to start from position 0. "
892 "However, got ", offsets[0]);
893 TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not "
894 "be greater than input's length ", indices.size(0), " but got offsets[-1] of ",
895 offset_n);
896 }
897 });
898
899 if (per_sample_weights.has_value() && per_sample_weights.value().defined()) {
900 TORCH_CHECK(
901 mode == EmbeddingBagMode::SUM,
902 "embedding_bag: per_sample_weights only supported with mode='sum'");
903 auto per_input_weights_arg = TensorArg(
904 per_sample_weights.value(),"per_sample_weights", 1);
905 checkSameType("embedding_bag", weight_arg, per_input_weights_arg);
906 TORCH_CHECK(per_sample_weights.value().dim() == 1);
907 TORCH_CHECK(per_sample_weights.value().numel() == indices.numel());
908 }
909
910 if (include_last_offset) {
911 TORCH_CHECK(
912 offsets.size(0) >= 1,
913 "include_last_offset: number of offset should be at least 1");
914 }
915 }
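// Example of the invariants enforced above (illustrative): with
//   indices = [3, 1, 4, 1, 5]  (kLong)
//   offsets = [0, 2, 4]        (kLong, matching indices' dtype)
// the checks pass, since offsets[0] == 0 and offsets[-1] == 4 <= indices.size(0).
// An offsets tensor starting at a nonzero value, or per_sample_weights
// combined with a mode other than 'sum', is rejected.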
916
917 void make_bag_size_out(
918 Tensor& bag_size_out,
919 const Tensor& offsets,
920 const Tensor& indices,
921 const int64_t mode,
922 const bool include_last_offset,
923 const bool requires_grad) {
924 if (requires_grad || mode == EmbeddingBagMode::MEAN ||
925 mode == EmbeddingBagMode::MAX) {
926 auto num_bags = offsets.size(0) - (include_last_offset ? 1 : 0);
927 at::native::resize_(bag_size_out, {num_bags}, std::nullopt);
928 // Compute this for EmbeddingBagMode::MEAN and EmbeddingBagMode::MAX (latter
929 // needed for backwards)
930 if (num_bags != 1) {
931 bag_size_out.slice(0, 0, bag_size_out.size(0) - 1, 1) =
932 offsets.slice(0, 1, num_bags, 1) -
933 offsets.slice(0, 0, num_bags - 1, 1);
934 }
935 if (num_bags > 0) {
936 bag_size_out[-1] = indices.size(0) - offsets[num_bags - 1];
937 }
938 } else {
939 at::native::resize_(bag_size_out, offsets.sizes(), std::nullopt);
940 }
941 }
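// Worked example (illustrative): indices.size(0) == 5, offsets == [0, 2, 4],
// include_last_offset == false, mode == MEAN:
//   bag_size[0..1] = offsets[1..2] - offsets[0..1] = [2, 2]
//   bag_size[2]    = 5 - offsets[2]                = 1
// so bag_size == [2, 2, 1]. For a plain forward 'sum' without gradients the
// tensor is only resized here; it is zeroed later for determinism.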
942
943 void make_max_indices_out(
944 Tensor& max_indices_out,
945 const Tensor& weight,
946 const Tensor& indices,
947 const Tensor& offsets,
948 const Tensor& bag_size,
949 const int64_t mode,
950 bool include_last_offset) {
951 int64_t numBags = offsets.size(0);
952 if (mode == EmbeddingBagMode::MAX) {
953 if (include_last_offset) {
954 TORCH_CHECK(
955 numBags >= 1, "include_last_offset: numBags should be at least 1");
956 numBags -= 1;
957 }
958 at::native::resize_(max_indices_out, {numBags, weight.sizes()[1]}, std::nullopt);
959 at::native::zero_(max_indices_out);
960 } else {
961 at::native::resize_(max_indices_out, bag_size.sizes(), std::nullopt);
962 }
963 }
964
965 void make_offset2bag_out(
966 Tensor& offset2bag,
967 Tensor& output,
968 const Tensor& weight,
969 const Tensor& indices,
970 const Tensor& offsets,
971 const int64_t mode,
972 const std::optional<Tensor>& per_sample_weights,
973 const int64_t padding_idx) {
974 // To save compute, if we are going to go down the fast path case for the 'sum'
975 // mode, we skip calculating offset2bag, since it is not going to be used.
976 bool fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx);
977
978 if (mode == EmbeddingBagMode::MEAN || mode == EmbeddingBagMode::MAX ||
979 !fast_path_sum) {
980 at::native::resize_(offset2bag, {indices.size(0) + 1}, std::nullopt);
981 at::native::zero_(offset2bag);
982
983 int64_t offsets_size = offsets.size(0);
984 bool include_last_offset = (output.size(0) == offsets_size - 1);
985 // when include_last_offset is true, ignore the last entry in offsets.
986 // fix segfault when include_last_offset is true and offsets[-1] != indices.size(0)
987 // see https://github.com/pytorch/pytorch/issues/89677 for more details.
988 Tensor _offsets = offsets;
989 if (include_last_offset) {
990 _offsets = offsets.narrow(0, 0, offsets_size - 1);
991 }
992 make_offset2bag(_offsets, offset2bag);
993 at::native::resize_(offset2bag, {indices.size(0)}, std::nullopt);
994 // only initialize output in slow path
995 at::native::zero_(output);
996 }
997 }
998
999 static Tensor make_bag_size(
1000 const Tensor& offsets,
1001 const Tensor& indices,
1002 const int64_t mode,
1003 const bool include_last_offset,
1004 const bool requires_grad) {
1005 Tensor bag_size = at::empty(offsets.sizes(), offsets.options());
1006 make_bag_size_out(bag_size, offsets, indices, mode, include_last_offset, requires_grad);
1007 return bag_size;
1008 }
1009
1010 static Tensor make_max_indices(
1011 const Tensor& weight,
1012 const Tensor& indices,
1013 const Tensor& offsets,
1014 const Tensor& bag_size,
1015 const int64_t mode,
1016 bool include_last_offset) {
1017 Tensor max_indices = at::empty(bag_size.sizes(), offsets.options());
1018 make_max_indices_out(max_indices, weight, indices, offsets, bag_size, mode, include_last_offset);
1019 return max_indices;
1020 }
1021
1022 static Tensor make_offset2bag(
1023 Tensor& output,
1024 const Tensor& weight,
1025 const Tensor& indices,
1026 const Tensor& offsets,
1027 const int64_t mode,
1028 const std::optional<Tensor>& per_sample_weights,
1029 const int64_t padding_idx) {
1030 Tensor offset2bag = at::empty({0}, offsets.options());
1031 make_offset2bag_out(offset2bag, output, weight, indices, offsets, mode, per_sample_weights, padding_idx);
1032 return offset2bag;
1033 }
1034
1035 static Tensor apply_bag_size(
1036 const int64_t mode,
1037 Tensor &output,
1038 const Tensor &bag_size) {
1039 if (mode == EmbeddingBagMode::MEAN) {
1040 auto bag_size_ = at::max(bag_size, at::ones_like(bag_size, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
1041 .to(output.options())
1042 .unsqueeze(1)
1043 .expand_as(output);
1044 output /= bag_size_;
1045 }
1046 return output;
1047 }
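// Example (illustrative): with mode == MEAN and bag_size == [2, 0, 3], the
// divisor is max(bag_size, 1) == [2, 1, 3], so an empty bag keeps its zero
// output row instead of producing NaN from a 0/0 division.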
1048
1049 static Tensor apply_bag_size_backward(
1050 const int64_t mode,
1051 Tensor &output,
1052 const Tensor &offset2bag,
1053 const Tensor &bag_size) {
1054 if (mode == EmbeddingBagMode::MEAN) {
1055 auto inv_bag_size_ = (1 / bag_size.to(output.options()))
1056 .unsqueeze(1)
1057 .index_select(0, offset2bag);
1058 output *= inv_bag_size_;
1059 }
1060 return output;
1061 }
1062
1063 template <typename scalar_t>
1064 void embedding_bag_cpu_max_out(
1065 Tensor* max_indices,
1066 const Tensor& weight,
1067 const Tensor& indices,
1068 const Tensor& offset2bag,
1069 const Tensor& output,
1070 bool include_last_offset,
1071 Tensor& bag_size,
1072 int64_t padding_idx) {
1073 int64_t numIndices = indices.numel();
1074 int64_t featureSize = weight.size(1);
1075 int64_t vocab_size = weight.size(0);
1076 AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu_max_out", [&] {
1077 auto* indices_data = indices.const_data_ptr<index_t>();
1078 auto* offset2bag_data = offset2bag.data_ptr<index_t>();
1079
1080 index_t* max_indices_data = nullptr;
1081 int64_t max_indices_stride = 0;
1082 if (max_indices) {
1083 max_indices_data = max_indices->data_ptr<index_t>();
1084 max_indices_stride = max_indices->strides()[0];
1085 }
1086
1087 auto* weight_data = weight.const_data_ptr<scalar_t>();
1088 auto* output_data = output.data_ptr<scalar_t>();
1089 auto* bag_size_data = bag_size.data_ptr<index_t>();
1090 auto weight_stride0 = weight.strides()[0];
1091 auto weight_stride1 = weight.strides()[1];
1092 auto output_stride = output.strides()[0];
1093 int64_t numBags = bag_size.size(0);
1094 std::vector<bool> bag_empty(numBags, true);
1095
1096 for (const auto i : c10::irange(numIndices)) {
1097 auto bag = offset2bag_data[i];
1098 auto word_idx = indices_data[i];
1099 TORCH_CHECK(
1100 word_idx >= 0 && word_idx < vocab_size,
1101 "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
1102 word_idx);
1103 if (word_idx != static_cast<index_t>(padding_idx)) {
1104 bool is_first_for_bag = bag_empty[bag];
1105 for (const auto dim : c10::irange(featureSize)) {
1106 auto& current_item = output_data[output_stride * bag + dim];
1107 auto weight_item =
1108 weight_data[weight_stride0 * word_idx + dim * weight_stride1];
1109
1110 if (is_first_for_bag || (weight_item > current_item)) {
1111 current_item = weight_item;
1112 if (max_indices_data) {
1113 max_indices_data[max_indices_stride * bag + dim] = word_idx;
1114 }
1115 }
1116 }
1117 if (is_first_for_bag) {
1118 bag_empty[bag] = false;
1119 }
1120 } else {
1121 // Decrement bag_size to reflect that the index is padded
1122 bag_size_data[bag]--;
1123 }
1124 }
1125 });
1126 }
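// Example (illustrative): for a bag that gathers rows w[2] and w[7], the
// output row is the element-wise maximum of the two, and max_indices records,
// per feature, whether row 2 or row 7 supplied that maximum; indices equal to
// padding_idx neither update the maximum nor count toward bag_size.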
1127
1128 void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
1129 Tensor& bag_size, Tensor* max_indices,
1130 const Tensor &weight, const Tensor &indices,
1131 const Tensor &offsets, const int64_t mode,
1132 const std::optional<Tensor>& per_sample_weights,
1133 bool include_last_offset, int64_t padding_idx, _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
1134 if (mode == EmbeddingBagMode::MEAN || mode == EmbeddingBagMode::SUM) {
1135 AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_no_grad_cpu_out",
1136 [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode, &bag_size, &padding_idx, &fbgemm_kernel_cache]() {
1137 AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_no_grad_cpu_out",
1138 [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode, &bag_size, &padding_idx, &fbgemm_kernel_cache]() {
1139 if (per_sample_weights.has_value() && per_sample_weights.value().defined()) {
1140 TORCH_INTERNAL_ASSERT(mode == EmbeddingBagMode::SUM);
1141 index_select_scale_add<scalar_t, index_t>(
1142 indices, offset2bag, per_sample_weights.value(), weight, output, offsets, include_last_offset, bag_size, padding_idx, fbgemm_kernel_cache);
1143 } else {
1144 index_select_add<scalar_t, index_t>(indices, offset2bag, weight, output, offsets, include_last_offset, bag_size, padding_idx, fbgemm_kernel_cache);
1145 }
1146 });
1147 });
1148 apply_bag_size(mode, output, bag_size);
1149 if (mode == EmbeddingBagMode::SUM) {
1150 // make bag_size output deterministic
1151 at::native::zero_(bag_size);
1152 }
1153 if (max_indices) {
1154 max_indices->copy_(bag_size);
1155 }
1156 } else { // EmbeddingBagMode::MAX
1157 AT_DISPATCH_FLOATING_TYPES_AND2(
1158 at::ScalarType::Half,
1159 at::ScalarType::BFloat16,
1160 weight.scalar_type(),
1161 "embedding_bag_cpu_max_out",
1162 [&]() {
1163 embedding_bag_cpu_max_out<scalar_t>(
1164 max_indices,
1165 weight,
1166 indices,
1167 offset2bag,
1168 output,
1169 include_last_offset,
1170 bag_size,
1171 padding_idx);
1172 });
1173 }
1174 }
1175
1176 // Assumes all input tensors except for `weight` are contiguous.
1177 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
1178 static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
1179 const Tensor& weight,
1180 const Tensor& indices_,
1181 const Tensor& offsets_,
1182 const int64_t mode,
1183 const Tensor& per_sample_weights,
1184 bool include_last_offset,
1185 int64_t padding_idx,
1186 bool requires_grad) {
1187 TORCH_CHECK(indices_.dim() == 1 || indices_.dim() == 2,
1188 "input has to be a 1D or 2D Tensor, but got Tensor of dimension ",
1189 indices_.dim());
1190 if (indices_.dim() == 1) {
1191 TORCH_CHECK(offsets_.dim() == 1,
1192 "offsets has to be a 1D Tensor, but got Tensor of dimension ",
1193 offsets_.dim());
1194 }
1195 TORCH_CHECK(weight.dim() == 2,
1196 "weight has to be a 2D Tensor, but got Tensor of dimension ",
1197 weight.dim());
1198 auto [indicesMaybeOwned, offsetsMaybeOwned] = promoteIndicesAndOffsets(indices_, offsets_);
1199 const auto& indices = *indicesMaybeOwned;
1200 const auto& offsets = *offsetsMaybeOwned;
1201 check_arguments(weight, indices, offsets, mode, per_sample_weights, include_last_offset);
1202
1203 Tensor output = at::empty(
1204 {include_last_offset ? offsets.size(0) - 1 : offsets.size(0),
1205 weight.sizes()[1]},
1206 weight.options());
1207
1208 Tensor offset2bag = make_offset2bag(output, weight, indices, offsets, mode, per_sample_weights, padding_idx);
1209
1210 Tensor bag_size = make_bag_size(offsets, indices, mode, include_last_offset, requires_grad);
1211
1212 Tensor max_indices = make_max_indices(weight, indices, offsets, bag_size, mode, include_last_offset);
1213
1214 _embedding_bag_cpu_impl_out(output, offset2bag,
1215 bag_size, &max_indices,
1216 weight, indices, offsets,
1217 mode, per_sample_weights,
1218 include_last_offset, padding_idx);
1219
1220 return std::make_tuple(std::move(output), std::move(offset2bag), std::move(bag_size), std::move(max_indices));
1221 }
1222
1223 // embedding_bag wrapper to enforce contiguity in tensors other than `weight`.
1224 // This is created to save extra `.contiguous()` call in backward.
1225 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
1226 std::tuple<Tensor, Tensor, Tensor, Tensor>
1227 embedding_bag(const Tensor &weight, const Tensor &indices,
1228 const Tensor &offsets, const bool scale_grad_by_freq,
1229 const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
1230 bool include_last_offset, std::optional<int64_t> padding_idx_opt) {
1231 // See [Note: hacky wrapper removal for optional tensor]
1232 c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
1233 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
1234 int64_t padding_idx = -1;
1235
1236 if (padding_idx_opt.has_value()) {
1237 auto num_embeddings = weight.size(0);
1238 padding_idx = padding_idx_opt.value();
1239 TORCH_CHECK(
1240 (padding_idx >= -num_embeddings) && (padding_idx < num_embeddings),
1241 "padding_idx must be within the number of embeddings, -", num_embeddings,
1242 " through ", num_embeddings - 1, ", but got ", padding_idx);
1243 padding_idx = maybe_wrap_dim(padding_idx, weight.size(0));
1244 }
1245 std::tuple<Tensor, Tensor, Tensor, Tensor> out;
1246 if (!weight.requires_grad() && !weight._fw_grad(/*level=*/0).defined()) {
1247 out = at::_embedding_bag_forward_only(
1248 weight, indices.contiguous(), offsets.contiguous(), scale_grad_by_freq,
1249 mode, sparse, per_sample_weights, include_last_offset, padding_idx);
1250 } else {
1251 out = at::_embedding_bag(
1252 weight, indices.contiguous(), offsets.contiguous(), scale_grad_by_freq,
1253 mode, sparse, per_sample_weights, include_last_offset, padding_idx);
1254 }
1255 return out;
1256 };
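// Usage sketch (illustrative, comment only; assumes the usual ATen factory
// headers for at::randn/at::tensor are available):
//   auto weight  = at::randn({10, 4}, at::requires_grad());
//   auto indices = at::tensor({1, 2, 4, 5, 4, 3}, at::kLong);
//   auto offsets = at::tensor({0, 2, 4}, at::kLong);
//   auto [out, offset2bag, bag_size, max_idx] = at::embedding_bag(
//       weight, indices, offsets, /*scale_grad_by_freq=*/false, /*mode=*/0,
//       /*sparse=*/false, /*per_sample_weights=*/{},
//       /*include_last_offset=*/false, /*padding_idx=*/-1);
//   // padding_idx == -1 wraps to 9 (num_embeddings - 1); out has shape {3, 4}.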
1257
1258 std::tuple<Tensor, Tensor, Tensor, Tensor>
1259 embedding_bag(const Tensor &weight, const Tensor &indices,
1260 const Tensor &offsets, const bool scale_grad_by_freq,
1261 const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
1262 bool include_last_offset) {
1263 return at::native::embedding_bag(weight, indices, offsets, scale_grad_by_freq,
1264 mode, sparse, per_sample_weights_opt, include_last_offset, std::nullopt);
1265 }
1266
1267 // Assumes all input tensors except for `weight` are contiguous.
1268 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
1269 std::tuple<Tensor, Tensor, Tensor, Tensor>
1270 _embedding_bag_forward_only_cpu(const Tensor &weight, const Tensor &indices,
1271 const Tensor &offsets, const bool scale_grad_by_freq,
1272 const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt, bool include_last_offset,
1273 int64_t padding_idx) {
1274 // See [Note: hacky wrapper removal for optional tensor]
1275 c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
1276 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
1277 std::ignore = scale_grad_by_freq;
1278 std::ignore = sparse;
1279 return _embedding_bag_cpu_impl(
1280 weight,
1281 indices,
1282 offsets,
1283 mode,
1284 per_sample_weights,
1285 include_last_offset,
1286 padding_idx,
1287 /*requires_grad=*/false);
1288 }
1289
1290 // Assumes all input tensors except for `weight` are contiguous.
1291 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
1292 std::tuple<Tensor, Tensor, Tensor, Tensor>
1293 _embedding_bag_cpu(const Tensor &weight, const Tensor &indices,
1294 const Tensor &offsets, const bool scale_grad_by_freq,
1295 const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt, bool include_last_offset,
1296 int64_t padding_idx) {
1297 // See [Note: hacky wrapper removal for optional tensor]
1298 c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
1299 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
1300
1301 std::ignore = scale_grad_by_freq;
1302 std::ignore = sparse;
1303 return _embedding_bag_cpu_impl(
1304 weight,
1305 indices,
1306 offsets,
1307 mode,
1308 per_sample_weights,
1309 include_last_offset,
1310 padding_idx,
1311 /*requires_grad=*/true);
1312 }
1313
1314 void _embedding_bag_cpu_out(
1315 at::Tensor& output,
1316 at::Tensor& offset2bag,
1317 at::Tensor& bag_size,
1318 at::Tensor* p_max_indices,
1319 const at::Tensor& weight,
1320 const at::Tensor& indices_,
1321 const at::Tensor& offsets_,
1322 const bool /* scale_grad_by_freq */,
1323 const int64_t mode,
1324 const bool /* sparse */,
1325 const std::optional<at::Tensor>& per_sample_weights,
1326 const bool include_last_offset,
1327 const std::optional<int64_t>& padding_idx,
1328 _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
1329 auto [indicesMaybeOwned, offsetsMaybeOwned] = promoteIndicesAndOffsets(indices_, offsets_);
1330 const auto& indices = *indicesMaybeOwned;
1331 const auto& offsets = *offsetsMaybeOwned;
1332 at::native::check_arguments(
1333 weight, indices, offsets, mode, per_sample_weights, include_last_offset);
1334
1335 at::native::make_offset2bag_out(
1336 offset2bag,
1337 output,
1338 weight,
1339 indices,
1340 offsets,
1341 mode,
1342 per_sample_weights,
1343 padding_idx.value_or(-1));
1344
1345 at::native::make_bag_size_out(
1346 bag_size, offsets, indices, mode, include_last_offset, false);
1347
1348 if (p_max_indices) {
1349 at::native::make_max_indices_out(
1350 *p_max_indices,
1351 weight,
1352 indices,
1353 offsets,
1354 bag_size,
1355 mode,
1356 include_last_offset);
1357 }
1358
1359 at::native::_embedding_bag_cpu_impl_out(
1360 output,
1361 offset2bag,
1362 bag_size,
1363 p_max_indices,
1364 weight,
1365 indices,
1366 offsets,
1367 mode,
1368 per_sample_weights,
1369 include_last_offset,
1370 padding_idx.value_or(-1),
1371 fbgemm_kernel_cache);
1372 }
1373
1374 Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_,
1375 const Tensor &offsets_,
1376 const Tensor &offset2bag,
1377 const Tensor &bag_size_,
1378 const Tensor &max_indices_,
1379 int64_t num_weights,
1380 bool scale_grad_by_freq, int64_t mode,
1381 bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
1382 int64_t padding_idx) {
1383 return at::native::_embedding_bag_backward_symint(
1384 grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights_opt, padding_idx);
1385 }
1386
1387 // Assumes all input tensors are contiguous.
1388 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
1389 Tensor _embedding_bag_backward_symint(const Tensor &grad, const Tensor &indices_,
1390 const Tensor &offsets_,
1391 const Tensor &offset2bag,
1392 const Tensor &bag_size_,
1393 const Tensor &max_indices_,
1394 c10::SymInt num_weights,
1395 bool scale_grad_by_freq, int64_t mode,
1396 bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
1397 int64_t padding_idx) {
1398 // See [Note: hacky wrapper removal for optional tensor]
1399 c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
1400 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
1401
1402 auto [indicesMaybeOwned, offsetsMaybeOwned] = promoteIndicesAndOffsets(indices_, offsets_);
1403 const auto& indices = *indicesMaybeOwned;
1404 const auto& offsets = *offsetsMaybeOwned;
1405 auto indices_arg = TensorArg(indices, "indices", 1);
1406 checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
1407 checkContiguous("embedding_bag", indices_arg);
1408 auto offsets_arg = TensorArg(offsets, "offsets", 1);
1409 checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt});
1410 checkSameType("embedding_bag", indices_arg, offsets_arg);
1411 checkContiguous("embedding_bag", offsets_arg);
1412
1413 Tensor offset2bag_;
1414 if (indices.sym_numel() != 0 && offset2bag.sym_numel() == 0) {
1415 offset2bag_ = offsets.new_zeros(
1416 {indices.size(0) + 1}, offsets.options()); // offset2bag = [0 0 0 0 0]
1417
1418 make_offset2bag(offsets, offset2bag_);
1419 // For Composite Compliance, if `offset2bag_` is CCT
1420 // then we can't call `resize_`. Instead we call `narrow`
1421 // to slice the tensor.
1422 if (isTensorSubclassLike(offset2bag_)) {
1423 offset2bag_ = offset2bag_.narrow(0, 0, indices.size(0));
1424 } else {
1425 offset2bag_.resize_({indices.size(0)});
1426 }
1427 } else {
1428 auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
1429 checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt});
1430 checkContiguous("embedding_bag", offset2bag_arg);
1431 offset2bag_ = offset2bag;
1432 }
1433
1434 if (sparse) {
1435 return at::_embedding_bag_sparse_backward_symint(
1436 grad, indices, offsets, offset2bag_, bag_size_, std::move(num_weights),
1437 scale_grad_by_freq, mode, per_sample_weights, padding_idx);
1438 } else {
1439 return at::_embedding_bag_dense_backward_symint(
1440 grad, indices, offset2bag_, bag_size_, max_indices_, std::move(num_weights),
1441 scale_grad_by_freq, mode, per_sample_weights, padding_idx);
1442 }
1443 }
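// Worked example of the offset2bag reconstruction above (hypothetical values
// chosen for illustration): with 5 indices and offsets = [0, 2, 3], the
// rebuilt offset2bag is [0, 0, 1, 2, 2], i.e. element i of `indices`
// contributes to bag offset2bag[i]. The same reconstructed mapping is then
// passed either to _embedding_bag_sparse_backward_symint (sparse weight
// gradient) or to _embedding_bag_dense_backward_symint (dense weight
// gradient), depending on the `sparse` flag.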
1444
1445 static Tensor _embedding_bag_dense_backward_cpu_max(
1446 const Tensor& grad,
1447 const Tensor& bag_size,
1448 const Tensor& max_indices,
1449 int64_t num_weights) {
1450 AT_ASSERT(max_indices.defined());
1451 auto index_grad_weight =
1452 at::zeros({num_weights, grad.sizes()[1]}, grad.options());
1453 auto nonempty_max_indices = max_indices.index_select(0, bag_size.nonzero().view(-1));
1454 auto nonempty_grad = grad.index_select(0, bag_size.nonzero().view(-1));
1455
1456 for (const auto dim : c10::irange(grad.sizes()[1])) {
1457 index_grad_weight.select(1, dim).index_add_(
1458 0, nonempty_max_indices.select(1, dim), nonempty_grad.select(1, dim));
1459 }
1460 return index_grad_weight;
1461 }
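// Worked example with hypothetical values: for num_weights = 4, two
// non-empty bags and an embedding dimension of 2,
//   grad        = [[g00, g01],
//                  [g10, g11]]
//   max_indices = [[0, 2],
//                  [3, 3]]
// the gradient is scattered per feature column into the rows that produced
// the maximum:
//   index_grad_weight[0][0] += g00,  index_grad_weight[2][1] += g01,
//   index_grad_weight[3][0] += g10,  index_grad_weight[3][1] += g11,
// and every other row of index_grad_weight stays zero. Empty bags are
// filtered out beforehand via bag_size.nonzero().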
1462
1463 template<typename index_t>
1464 static std::vector<index_t> compute_counts(
1465 int64_t num_weights,
1466 const index_t* indices_data,
1467 int64_t indices_length) {
1468 std::vector<index_t> counts(num_weights, 0);
1469 for (const auto i : c10::irange(indices_length)) {
1470 counts[indices_data[i]]++;
1471 }
1472 return counts;
1473 }
1474
1475 // counts_uniq stores, for each unique value in the (sorted) indices vector,
1476 // the index of the NEXT unique element, i.e. one past its last occurrence.
1477 //
1478 // For example (counts shown for num_weights == 6):
1479 // indices: [0, 0, 0, 1, 3, 3, 4]
1480 // counts: [3, 1, 0, 2, 1, 0]
1481 // counts_uniq: [3, 4, 6, 7]
1482 //
1483 // The first occurrence of each unique value is at index 0, 3, 4, 6.
1484 template<typename index_t>
1485 static std::vector<index_t> compute_counts_uniq(
1486 int64_t num_weights,
1487 const index_t* indices_data,
1488 int64_t indices_length,
1489 const std::vector<index_t>& counts) {
1490 std::vector<index_t> counts_uniq;
1491 counts_uniq.reserve(num_weights);
1492 int64_t o = 0;
1493 for (int64_t i = 0; i < indices_length; i += counts[indices_data[i]]) {
1494 counts_uniq.push_back(counts[indices_data[i]]);
1495 if (o > 0) {
1496 counts_uniq[o] += counts_uniq[o - 1];
1497 }
1498 o++;
1499 }
1500 return counts_uniq;
1501 }
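// Illustrative sketch of how the two helpers above are combined in the
// sum/mean backward below: for the sorted indices = [0, 0, 0, 1, 3, 3, 4]
// and counts = [3, 1, 0, 2, 1, 0], counts_uniq = [3, 4, 6, 7] partitions the
// sorted indices into the per-row segments [0, 3), [3, 4), [4, 6) and
// [6, 7), so each segment can be processed independently (and in parallel),
// with the embedding row of a segment given by the index at its first
// position.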
1502
1503 template <typename scalar_t>
1504 void _embedding_bag_dense_backward_cpu_sum_mean(
1505 const Tensor& grad,
1506 const Tensor& indices_,
1507 const Tensor& offset2bag__,
1508 const Tensor& bag_size_,
1509 int64_t num_weights,
1510 bool scale_grad_by_freq,
1511 int64_t mode,
1512 const Tensor& per_sample_weights_,
1513 Tensor& index_grad_weight,
1514 int64_t padding_idx) {
1515
1516 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1517 Tensor &offset2bag_ = const_cast<Tensor &>(offset2bag__);
1518
1519 auto ind_sort_ = indices_.sort();
1520 auto indices = std::get<0>(ind_sort_);
1521 auto ind_sort = std::get<1>(ind_sort_);
1522 auto offset2bag = offset2bag_.index_select(0, ind_sort);
1523
1524 std::optional<Tensor> per_sample_weights;
1525 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1526 const scalar_t* per_sample_weights_data;
1527 std::optional<int64_t> per_sample_weights_stride;
1528 if (per_sample_weights_.defined()) {
1529 per_sample_weights = per_sample_weights_.index_select(0, ind_sort);
1530 per_sample_weights_data = per_sample_weights->const_data_ptr<scalar_t>();
1531 per_sample_weights_stride = per_sample_weights->strides()[0];
1532 }
1533
1534 int64_t numel = indices.numel();
1535
1536 // Explicitly capture all required variables to work around a Windows build issue.
1537 // TODO: fix this once Windows can correctly capture variables in nested lambdas.
1538 AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_dense_backward_cpu_sum_mean",
1539 [&indices, &offset2bag, &bag_size_, &num_weights, &numel, &per_sample_weights,
1540 &per_sample_weights_data, &per_sample_weights_stride, &mode, &scale_grad_by_freq,
1541 &grad, &index_grad_weight, &padding_idx] {
1542 auto* indices_data = indices.const_data_ptr<index_t>();
1543 auto* offset2bag_data = offset2bag.const_data_ptr<index_t>();
1544 auto* bag_size_data = bag_size_.const_data_ptr<index_t>();
1545
1546 auto counts = compute_counts(num_weights, indices_data, numel);
1547 auto next_unique_index_idx =
1548 compute_counts_uniq(num_weights, indices_data, numel, counts);
1549
1550 auto loop =
1551 [&next_unique_index_idx, &indices_data, &offset2bag_data, &bag_size_data, &per_sample_weights,
1552 &mode, &per_sample_weights_data, &per_sample_weights_stride, &scale_grad_by_freq,
1553 &counts, &grad, &index_grad_weight, &padding_idx
1554 ](index_t start, index_t end) {
1555 for (index_t i = start; i < end; i++) {
1556 index_t start = i == 0 ? 0 : next_unique_index_idx[i - 1];
1557 index_t index = indices_data[start];
1558
1559 if (index != static_cast<index_t>(padding_idx)) {
1560 for (index_t j = start; j < next_unique_index_idx[i]; j++) {
1561 index_t source = offset2bag_data[j];
1562 double scale = 1.0;
1563 if (per_sample_weights) {
1564 AT_ASSERT(mode == EmbeddingBagMode::SUM);
1565 scale = per_sample_weights_data[*per_sample_weights_stride * j];
1566 }
1567 if (scale_grad_by_freq) {
1568 scale /= counts[index];
1569 }
1570 if (mode == EmbeddingBagMode::MEAN) {
1571 auto bag_size = bag_size_data[source];
1572 if (bag_size != 0) {
1573 scale /= bag_size;
1574 }
1575 }
1576 int64_t ddim = grad.size(1);
1577 auto igwd = index_grad_weight.data_ptr<scalar_t>();
1578 auto gd = grad.const_data_ptr<scalar_t>();
1579 at::native::cpublas::axpy<scalar_t>(ddim, (scalar_t)scale, gd + ddim * source, 1,
1580 igwd + ddim * index, 1);
1581 }
1582 }
1583 }
1584 };
1585
1586 if (numel > 1000) {
1587 at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop);
1588 } else {
1589 loop(0, (int64_t)next_unique_index_idx.size());
1590 }
1591 });
1592 }
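// In equation form, the loop above computes, for every non-padding embedding
// row w and every sample j mapped to it:
//
//   index_grad_weight[w] += scale * grad[offset2bag[j]]
//
// where scale starts at per_sample_weights[j] (SUM mode with weights) or 1,
// is divided by the row's frequency when scale_grad_by_freq is set, and is
// divided by the bag size when mode == MEAN.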
1593
1594 Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indices_,
1595 const Tensor &offset2bag__,
1596 const Tensor &bag_size_,
1597 const Tensor& max_indices_, int64_t num_weights,
1598 bool scale_grad_by_freq, int64_t mode, const std::optional<Tensor>& per_sample_weights__opt,
1599 int64_t padding_idx) {
1600 // See [Note: hacky wrapper removal for optional tensor]
1601 c10::MaybeOwned<Tensor> per_sample_weights__maybe_owned = at::borrow_from_optional_tensor(per_sample_weights__opt);
1602 const Tensor& per_sample_weights_ = *per_sample_weights__maybe_owned;
1603
1604 // indices_, offsets_ and offset2bag__ are assumed to have the correct dtypes
1605 // and to be contiguous here due to the checks in _embedding_bag_backward above.
1606 // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
1607 // for more details.
1608 auto grad = grad_.contiguous();
1609 auto grad_arg = TensorArg(grad, "grad_", 1);
1610 checkScalarTypes(
1611 "embedding_bag", grad_arg, {kHalf, kBFloat16, kFloat, kDouble});
1612
1613 if (mode == EmbeddingBagMode::MAX) {
1614 return _embedding_bag_dense_backward_cpu_max(
1615 grad_, bag_size_, max_indices_, num_weights);
1616 }
1617 AT_ASSERT(mode == EmbeddingBagMode::MEAN || mode == EmbeddingBagMode::SUM);
1618
1619 auto index_grad_weight =
1620 at::zeros({num_weights, grad.sizes()[1]}, grad.options());
1621
1622 AT_DISPATCH_FLOATING_TYPES_AND2(
1623 at::ScalarType::Half,
1624 at::ScalarType::BFloat16,
1625 grad.scalar_type(),
1626 "embedding_bag_backward",
1627 [&] {
1628 _embedding_bag_dense_backward_cpu_sum_mean<scalar_t>(
1629 grad,
1630 indices_,
1631 offset2bag__,
1632 bag_size_,
1633 num_weights,
1634 scale_grad_by_freq,
1635 mode,
1636 per_sample_weights_,
1637 index_grad_weight,
1638 padding_idx);
1639 });
1640 return index_grad_weight;
1641 }
1642
1643 template<typename scalar_t>
1644 Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
1645 const Tensor& grad,
1646 const Tensor& weight, // NB: embedding table, not per_sample_weights
1647 const Tensor& indices_,
1648 const Tensor& offsets_,
1649 const Tensor& offset2bag,
1650 int64_t mode,
1651 int64_t padding_idx) {
1652 TORCH_CHECK(
1653 mode == EmbeddingBagMode::SUM,
1654 "embedding_bag_backward: per_sample_weights only supported for mode='sum'");
1655
1656 AT_ASSERT(grad.dim() == 2);
1657 auto embedding_features = grad.sizes()[1];
1658
1659 auto [indicesMaybeOwned, offsetsMaybeOwned] = promoteIndicesAndOffsets(indices_, offsets_);
1660 const auto& indices = *indicesMaybeOwned;
1661 const auto& offsets = *offsetsMaybeOwned;
1662
1663 AT_ASSERT(indices.dim() == 1);
1664 auto num_samples = indices.size(0);
1665
1666 AT_ASSERT(weight.dim() == 2);
1667 AT_ASSERT(weight.sizes()[1] == embedding_features);
1668
1669 auto output = at::zeros({num_samples}, grad.options());
1670
1671 auto indices_arg = TensorArg(indices, "indices", 1);
1672 checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
1673 checkContiguous("embedding_bag", indices_arg);
1674
1675 Tensor offset2bag_;
1676 if (indices.numel() != 0 && offset2bag.numel() == 0) {
1677 offset2bag_ = at::zeros(
1678 {indices.size(0) + 1}, offset2bag.options()); // offset2bag = [0 0 0 0 0]
1679
1680 make_offset2bag(offsets, offset2bag_);
1681
1682 at::native::resize_(offset2bag_, {indices.size(0)}, std::nullopt);
1683 } else {
1684 auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
1685 checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt});
1686 checkContiguous("embedding_bag", offset2bag_arg);
1687 offset2bag_ = offset2bag;
1688 }
1689
1690 auto* grad_data = grad.const_data_ptr<scalar_t>();
1691 auto grad_stride0 = grad.strides()[0];
1692 auto grad_stride1 = grad.strides()[1];
1693
1694 auto* weight_data = weight.const_data_ptr<scalar_t>();
1695 auto weight_stride0 = weight.strides()[0];
1696 auto weight_stride1 = weight.strides()[1];
1697
1698 // Explicitly capture all required variables to work around a Windows build issue.
1699 // TODO: fix this once Windows can correctly capture variables in nested lambdas.
1700 AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu_template",
1701 [&indices, &output, &offset2bag_, &num_samples, &embedding_features,
1702 &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, &weight_stride1,
1703 &padding_idx] () {
1704 auto* indices_data = indices.const_data_ptr<index_t>();
1705
1706 // The following are contiguous
1707 auto* output_data = output.data_ptr<scalar_t>();
1708 auto* offset2bag_data = offset2bag_.const_data_ptr<index_t>();
1709
1710 // XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number.
1711 parallel_for(0, num_samples, 64,
1712 [&embedding_features, &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0,
1713 &weight_stride1, &offset2bag_data, &indices_data, &output_data, &padding_idx](index_t begin, index_t end) {
1714 for (index_t sample_idx = begin; sample_idx < end; sample_idx++) {
1715 auto bag_idx = offset2bag_data[sample_idx];
1716 auto embedding_idx = indices_data[sample_idx];
1717
1718 if (embedding_idx != static_cast<index_t>(padding_idx)) {
1719 output_data[sample_idx] = dot_impl<scalar_t>(
1720 embedding_features,
1721 const_cast<scalar_t*>(grad_data + grad_stride0 * bag_idx), grad_stride1,
1722 const_cast<scalar_t*>(weight_data + weight_stride0 * embedding_idx), weight_stride1);
1723 }
1724 }
1725 });
1726 });
1727 return output;
1728 }
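// In equation form, the kernel above computes
//
//   output[i] = <grad[offset2bag[i]], weight[indices[i]]>
//
// for every sample i whose index is not padding_idx, i.e. the gradient with
// respect to per_sample_weights[i] is the dot product between the incoming
// gradient row of sample i's bag and the selected embedding row.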
1729
1730 Tensor _embedding_bag_per_sample_weights_backward_cpu(
1731 const Tensor& grad,
1732 const Tensor& weight, // NB: embedding table, not per_sample_weights
1733 const Tensor& indices,
1734 const Tensor& offsets,
1735 const Tensor& offset2bag,
1736 int64_t mode,
1737 int64_t padding_idx) {
1738 return AT_DISPATCH_FLOATING_TYPES_AND2(
1739 at::ScalarType::Half,
1740 at::ScalarType::BFloat16,
1741 grad.scalar_type(),
1742 "_embedding_bag_per_sample_weights_backward_cpu",
1743 [&]() {
1744 return _embedding_bag_per_sample_weights_backward_cpu_template<
1745 scalar_t>(
1746 grad, weight, indices, offsets, offset2bag, mode, padding_idx);
1747 });
1748 }
1749
1750 Tensor _embedding_bag_sparse_backward_symint(
1751 const Tensor &grad_, const Tensor &indices, const Tensor &offsets,
1752 const Tensor &offset2bag, const Tensor &bag_size_, SymInt num_weights,
1753 bool scale_grad_by_freq, int64_t mode, const std::optional<Tensor>& per_sample_weights_opt,
1754 int64_t padding_idx) {
1755 // See [Note: hacky wrapper removal for optional tensor]
1756 c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
1757 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
1758
1759 // indices, offsets and offset2bag are assumed to have the correct dtypes
1760 // and to be contiguous here due to the checks in _embedding_bag_backward above.
1761 // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
1762 // for more details.
1763
1764 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
1765 Tensor grad = grad_;
1766 Tensor index_grad = grad_.index_select(0, offset2bag);
1767
1768 index_grad = apply_bag_size_backward(mode, index_grad, offset2bag, bag_size_);
1769
1770 if (per_sample_weights.defined()) {
1771 AT_ASSERT(mode == EmbeddingBagMode::SUM);
1772 index_grad.mul_(per_sample_weights.unsqueeze(1));
1773 }
1774 return native::embedding_backward_symint(index_grad, indices, std::move(num_weights), padding_idx,
1775 scale_grad_by_freq, true);
1776 }
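// Sketch of the sparse path above: each bag's incoming gradient row is first
// gathered back to per-sample rows via grad_.index_select(0, offset2bag),
// rescaled by apply_bag_size_backward (MEAN mode) and per_sample_weights as
// needed, and then handed to embedding_backward_symint with sparse=true,
// which returns the weight gradient as a sparse tensor indexed by `indices`.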
1777 } // namespace at::native
1778