#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/resize_native.h>
#endif

namespace at {
namespace native {

// BEGIN QUANTIZE HELPER FUNCTIONS
__device__ __forceinline__ float bfe(uint32_t val, uint32_t pos, uint32_t len) {
#ifdef USE_ROCM
  return static_cast<float>((val >> pos) & ((1u << len) - 1u));
#else
  uint32_t ret;
  // Get the bit field of [pos, pos+len) bits from val:
  // (val >> pos) & ( (1u << len) - 1u )
  asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
  return __uint2float_rn(ret);
#endif
}
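
// Worked example for bfe() (illustrative): bfe(0x000000A3u, 4, 4) extracts
// bits [4, 8) of 0xA3, i.e. 0xA, and returns 10.0f. A host-side reference with
// the same semantics (a sketch, not used by this file) would be:
//
//   inline float bfe_reference(uint32_t val, uint32_t pos, uint32_t len) {
//     return static_cast<float>((val >> pos) & ((1u << len) - 1u));
//   }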

// FMA with constant scale/bias for all 4 floats in fa
__forceinline__ __device__ float4
fma4sb(const float4 fa, const float fscale, const float fbias) {
  float4 res;
#ifdef USE_ROCM
  res.x = fa.x * fscale + fbias;
  res.y = fa.y * fscale + fbias;
  res.z = fa.z * fscale + fbias;
  res.w = fa.w * fscale + fbias;
#else
  res.x = fmaf(fa.x, fscale, fbias);
  res.y = fmaf(fa.y, fscale, fbias);
  res.z = fmaf(fa.z, fscale, fbias);
  res.w = fmaf(fa.w, fscale, fbias);
#endif
  return res;
}

template <uint8_t bits_per_dim>
__forceinline__ __device__ float4
dequantize_intx(uint32_t packedVals, float2 scale_bias, uint8_t offset_bits) {
  float4 res;

  res.x = bfe(packedVals, offset_bits + (0 * bits_per_dim), bits_per_dim);
  res.y = bfe(packedVals, offset_bits + (1 * bits_per_dim), bits_per_dim);
  res.z = bfe(packedVals, offset_bits + (2 * bits_per_dim), bits_per_dim);
  res.w = bfe(packedVals, offset_bits + (3 * bits_per_dim), bits_per_dim);

  return fma4sb(res, scale_bias.x, scale_bias.y);
}
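
// Worked example for dequantize_intx (illustrative): with bits_per_dim == 4,
// offset_bits == 0 and packedVals == 0x87654321, the four extracted fields are
// the low nibbles {0x1, 0x2, 0x3, 0x4}; with offset_bits == 16 they are
// {0x5, 0x6, 0x7, 0x8}. Each field is then mapped to
//   value * scale_bias.x + scale_bias.y
// by fma4sb().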

template <uint8_t bits_per_dim>
__forceinline__ __device__ void
accumulate_packed_intx(float4* acc, uint32_t packedVals, float2 scale_bias, float sample_weight) {
  constexpr uint8_t dims_per_byte = 8 / bits_per_dim;
  for (uint8_t i = 0; i < dims_per_byte; i++) {
    float4 res = dequantize_intx<bits_per_dim>(packedVals, scale_bias, 4 * bits_per_dim * i /* offset_bits */);
    // Accumulate in float32.
    acc[i].x += (res.x * sample_weight);
    acc[i].y += (res.y * sample_weight);
    acc[i].z += (res.z * sample_weight);
    acc[i].w += (res.w * sample_weight);
  }
}
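
// Worked example for accumulate_packed_intx (illustrative): each call consumes
// one 32-bit word of packed values. For bits_per_dim == 4, dims_per_byte == 2,
// so the loop dequantizes two groups of four nibbles (8 dims per word); for
// bits_per_dim == 8, dims_per_byte == 1, so a single group of four bytes
// (4 dims per word) is dequantized, weighted by sample_weight, and accumulated
// into acc.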

// END QUANTIZE HELPER FUNCTIONS

// UN-OPTIMIZED kernel, doesn't even avoid warp divergence!
template <typename index_t, uint8_t bits_per_dim>
__global__ void embedding_bag_nbits_rowwise_offsets_kernel(
    const PackedTensorAccessor64<uint8_t, 2, RestrictPtrTraits> weight,
    const PackedTensorAccessor32<index_t, 1, RestrictPtrTraits> indices,
    const PackedTensorAccessor32<index_t, 1, RestrictPtrTraits> offsets,
    const bool /* pruned_weights */,
    const PackedTensorAccessor32<float, 1, RestrictPtrTraits> per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    const bool include_last_offset,
    PackedTensorAccessor32<float, 2, RestrictPtrTraits> output) {
  static_assert(
      bits_per_dim == 4 || bits_per_dim == 8,
      "embedding_bag_nbits_rowwise_offsets_kernel has only been tested for 4 and 8 bits per dim");
  constexpr uint8_t dims_per_byte = 8 / bits_per_dim;
  constexpr bool fp32_scale_bias = bits_per_dim == 8;

  int32_t B = output.size(0);
  int32_t D = output.size(1);
  int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
  if (b_t >= B * D) {
    return;
  }
  int32_t t = b_t / B;
  int32_t b = b_t % B;

  const int32_t D_bytes = weight.size(1);

  bool use_per_sample = per_sample_weights_.size(0) > 0;

  int64_t indices_start = offsets[t * B + b];
  int64_t indices_end;
  if (include_last_offset) {
    indices_end = offsets[t * B + b + 1];
  } else {
    indices_end = (t * B + b + 1) < offsets.size(0) ? offsets[t * B + b + 1]
                                                    : indices.size(0);
  }

  int32_t L = indices_end - indices_start;
  const uint8_t* __restrict__ weights = &weight[0][0];

  if (L == 0) {
    for (int32_t d = 0; d < D; d += 4) {
      *(float4*)(&output[b][d]) = make_float4(0, 0, 0, 0);
    }
    return;
  }

  float4 accumulator[dims_per_byte];
  int32_t byte_offset = 0;
  for (int32_t d = 0; d < D; d += dims_per_byte * 4, byte_offset += 4) {
    for (int32_t i = 0; i < dims_per_byte; ++i) {
      accumulator[i] = make_float4(0, 0, 0, 0);
    }
    for (int32_t l = indices_start; l < indices_end; ++l) {
      int64_t idx = indices[l];
      float sample_weight = use_per_sample ? per_sample_weights_[l] : 1.0f;
      const uint8_t* __restrict__ row = &weights[idx * D_bytes];
      float2 scale_bias;
      if (fp32_scale_bias) {
        scale_bias = make_float2(
            reinterpret_cast<const float*>(&row[D_bytes - 8])[0],
            reinterpret_cast<const float*>(&row[D_bytes - 4])[0]);
      } else {
        scale_bias = make_float2(
            __half2float(reinterpret_cast<const __half*>(&row[D_bytes - 4])[0]),
            __half2float(reinterpret_cast<const __half*>(&row[D_bytes - 2])[0]));
      }

      uint32_t v0 = reinterpret_cast<const uint32_t*>(&row[byte_offset])[0];

      accumulate_packed_intx<bits_per_dim>(accumulator, v0, scale_bias, sample_weight);
    }

    for (int32_t i = 0; i < dims_per_byte; ++i) {
      *(float4*)(&output[b][d + (i * 4)]) = accumulator[i];
    }
  }
}
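
// Row layout consumed by the kernel above (as read in the inner loop):
//   bits_per_dim == 8: D quantized bytes, followed by an fp32 scale and an
//                      fp32 bias (D_bytes == D + 8; scale at D_bytes - 8,
//                      bias at D_bytes - 4).
//   bits_per_dim == 4: D/2 packed bytes (two dims per byte), followed by an
//                      fp16 scale and an fp16 bias (D_bytes == D/2 + 4;
//                      scale at D_bytes - 4, bias at D_bytes - 2).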

inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::ScalarType dtype) {
  return at::native::empty_cuda({0}, dtype, t.layout(), t.device(), false);
}

Tensor qembeddingbag_byte_unpack(const Tensor& packed_weight) {
  const auto packed_weight_sizes = packed_weight.sizes();
  const auto col_dim = packed_weight_sizes.size() - 1;
  const int32_t input_rows = c10::size_to_dim_(col_dim, packed_weight_sizes);
  const int32_t input_columns = packed_weight_sizes[col_dim];
  const int32_t output_columns = input_columns - 2 * sizeof(float);

  std::vector<int64_t> output_shape = packed_weight_sizes.vec();
  output_shape[col_dim] = output_columns;

  return at::empty(
      output_shape,
      packed_weight.options().dtype(kFloat),
      packed_weight.suggest_memory_format());
}
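
// Shape example for the unpack above (illustrative): a packed weight of shape
// [N, D + 8] uint8 (D quantized bytes plus an fp32 scale and an fp32 bias per
// row) maps to an output of shape [N, D] float32, since
// output_columns == input_columns - 2 * sizeof(float).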

template <typename IndexType, typename OffsetType>
at::Tensor& embedding_bag_byte_impl(
    at::Tensor& output,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  TORCH_CHECK(weight.is_cuda());
  TORCH_CHECK(indices.is_cuda());
  TORCH_CHECK(offsets.is_cuda());
  TORCH_CHECK(indices.device() == weight.device());
  TORCH_CHECK(offsets.device() == weight.device());
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(per_sample_weights_.value().device() == weight.device());
  }
  if (compressed_indices_mapping.has_value()) {
    TORCH_CHECK(compressed_indices_mapping.value().device() == weight.device());
  }

  TORCH_CHECK(weight.dtype() == at::kByte);
  TORCH_CHECK(weight.dim() == 2);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(weight.get_device());

  const auto weight_sizes = weight.sizes();
  const int64_t N = weight_sizes[0];
  const int D = weight_sizes[1] - 8; // NB: -8 to account for the fp32 scale and bias appended to each row
  const int64_t M = offsets.sizes()[0];
  TORCH_CHECK(D % 4 == 0);
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat,
        "Per sample weights expected scalar type ", at::kFloat, " but got ",
        per_sample_weights_.value().scalar_type());
  }
  TORCH_CHECK(
      !compressed_indices_mapping.has_value(),
      "Compressed indices mapping not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");

  const auto maxThreads = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;

  int64_t output_size = include_last_offset ? M - 1 : M;

  at::Tensor sample_weights;
  if (per_sample_weights_.has_value()) {
    sample_weights = per_sample_weights_.value();
  } else {
    sample_weights = create_empty_from(output, kFloat);
  }

  const std::vector<int64_t> shape = {output_size, D};
  at::native::resize_(output, shape, std::nullopt);
  AT_DISPATCH_INDEX_TYPES(
      indices.scalar_type(), "embedding_bag_byte_rowwise_offsets_kernel", ([&] {
        embedding_bag_nbits_rowwise_offsets_kernel<index_t, 8><<<
            output_size,
            dim3(1, 1, 1),
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            weight.packed_accessor64<uint8_t, 2, RestrictPtrTraits>(),
            indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
            offsets.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
            false /* pruned_weights */,
            sample_weights.packed_accessor32<float, 1, RestrictPtrTraits>(),
            compressed_indices_mapping,
            include_last_offset,
            output.packed_accessor32<float, 2, RestrictPtrTraits>());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }));

  TORCH_CHECK(output.is_cuda());

  return output;
}
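
// Worked example of the bag boundaries computed by the kernel for this op
// (illustrative): with offsets = {0, 2, 5} (M == 3) and indices.size(0) == 7,
//   include_last_offset == false -> output_size == 3, bags [0, 2), [2, 5), [5, 7)
//   include_last_offset == true  -> output_size == 2, bags [0, 2), [2, 5)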

Tensor embedding_bag_byte_rowwise_offsets(
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool pruned_weights,
    const std::optional<Tensor>& per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  bool is_embedding_op = false;
  auto output = create_empty_from(weight, at::kFloat);

  c10::MaybeOwned<at::Tensor> offsets;
  TORCH_CHECK(
      indices.dim() == 1 || indices.dim() == 2,
      "qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
      indices.dim());
  // For the embedding_bag operator with 2D indices, we set the offsets
  // explicitly here (see the worked example after this block).
  if (indices.dim() == 2 && !is_embedding_op) {
    TORCH_CHECK(
        !offsets_in.has_value(),
        "embedding_bag_byte operator: if the input is 2D, offsets have to be None, as the input is treated as a mini-batch of fixed length sequences.");

    offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(
        0, indices.numel(), indices.sizes()[1], indices.scalar_type()));

  } else {
    TORCH_CHECK(
        offsets_in.has_value(),
        "embedding_bag_byte expects offsets to be set for 1D indices.");
    offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
  }
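
  // Worked example of the offsets derived above (illustrative): indices of
  // shape [3, 5] (a mini-batch of 3 fixed-length sequences of 5 indices each)
  // yields offsets = arange(0, 15, 5) = {0, 5, 10}.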

  TORCH_CHECK(
      indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong,
      "Expect 32 or 64 bit indices, but found ",
      indices.scalar_type(),
      " instead.");
  TORCH_CHECK(
      offsets->scalar_type() == at::kInt || offsets->scalar_type() == at::kLong,
      "Expect 32 or 64 bit offsets, but found ",
      offsets->scalar_type(),
      " instead.");
  TORCH_CHECK(
      weight.is_contiguous() && indices.is_contiguous() &&
          offsets->is_contiguous(),
      "Expect weight, indices, and offsets to be contiguous.");

  // Use a helper function to support the different index/offset type
  // combinations without casting, which would add performance overhead.
  if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
    return embedding_bag_byte_impl<int, int>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  } else if (
      indices.scalar_type() == at::kInt &&
      offsets->scalar_type() == at::kLong) {
    return embedding_bag_byte_impl<int, int64_t>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  } else if (
      indices.scalar_type() == at::kLong &&
      offsets->scalar_type() == at::kInt) {
    return embedding_bag_byte_impl<int64_t, int>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  }

  // Default case (int64_t indices and offsets), given the TORCH_CHECKs above.
  return embedding_bag_byte_impl<int64_t, int64_t>(
      output,
      weight,
      indices,
      *offsets,
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset,
      is_embedding_op);
}
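
// A minimal host-side usage sketch for the byte op above (illustrative only;
// tensor contents and error handling omitted, values are assumptions):
//
//   // weight: [10, 16 + 8] uint8 on CUDA; the last 8 bytes of each row hold
//   // the fp32 scale and fp32 bias.
//   at::Tensor weight  = at::zeros({10, 24}, at::device(at::kCUDA).dtype(at::kByte));
//   at::Tensor indices = at::tensor({0, 1, 2, 5}, at::device(at::kCUDA).dtype(at::kLong));
//   at::Tensor offsets = at::tensor({0, 2, 4}, at::device(at::kCUDA).dtype(at::kLong));
//   at::Tensor out = embedding_bag_byte_rowwise_offsets(
//       weight, indices, offsets,
//       /*scale_grad_by_freq=*/false, /*mode=*/0, /*pruned_weights=*/false,
//       /*per_sample_weights_=*/std::nullopt,
//       /*compressed_indices_mapping=*/std::nullopt,
//       /*include_last_offset=*/true);
//   // out: [2, 16] float32 on CUDA (two bags of 16 dequantized dims each).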

template <typename IndexType, typename OffsetType>
at::Tensor& embedding_bag_4bit_impl(
    at::Tensor& output,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  TORCH_CHECK(weight.is_cuda());
  TORCH_CHECK(indices.is_cuda());
  TORCH_CHECK(offsets.is_cuda());
  TORCH_CHECK(indices.device() == weight.device());
  TORCH_CHECK(offsets.device() == weight.device());
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(per_sample_weights_.value().device() == weight.device());
  }
  if (compressed_indices_mapping.has_value()) {
    TORCH_CHECK(compressed_indices_mapping.value().device() == weight.device());
  }

  TORCH_CHECK(weight.dtype() == at::kByte);
  TORCH_CHECK(weight.dim() == 2);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(weight.get_device());

  const auto weight_sizes = weight.sizes();
  const int64_t N = weight_sizes[0];
  const int D = 2 * (weight_sizes[1] - 4); // NB: -4 to account for the fp16 scale and bias appended to each row
  const int64_t M = offsets.sizes()[0];
  TORCH_CHECK(D % 8 == 0);
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat,
        "Per sample weights expected scalar type ", at::kFloat, " but got ",
        per_sample_weights_.value().scalar_type());
  }
  TORCH_CHECK(
      !compressed_indices_mapping.has_value(),
      "Compressed indices mapping not yet implemented for embedding_bag_4bit_rowwise_offsets_cuda");

  const auto maxThreads = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;

  int64_t output_size = include_last_offset ? M - 1 : M;

  at::Tensor sample_weights;
  if (per_sample_weights_.has_value()) {
    sample_weights = per_sample_weights_.value();
  } else {
    sample_weights = create_empty_from(output, kFloat);
  }

  const std::vector<int64_t> shape = {output_size, D};
  at::native::resize_(output, shape, std::nullopt);
  AT_DISPATCH_INDEX_TYPES(
      indices.scalar_type(), "embedding_bag_4bit_rowwise_offsets_kernel", ([&] {
        embedding_bag_nbits_rowwise_offsets_kernel<index_t, 4><<<
            output_size,
            dim3(1, 1, 1),
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            weight.packed_accessor64<uint8_t, 2, RestrictPtrTraits>(),
            indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
            offsets.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
            false /* pruned_weights */,
            sample_weights.packed_accessor32<float, 1, RestrictPtrTraits>(),
            compressed_indices_mapping,
            include_last_offset,
            output.packed_accessor32<float, 2, RestrictPtrTraits>());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }));

  TORCH_CHECK(output.is_cuda());

  return output;
}

Tensor embedding_bag_4bit_rowwise_offsets(
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool pruned_weights,
    const std::optional<Tensor>& per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  auto output = create_empty_from(weight, at::kFloat);

  c10::MaybeOwned<at::Tensor> offsets;
  TORCH_CHECK(
      indices.dim() == 1 || indices.dim() == 2,
      "qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
      indices.dim());

  // For the embedding_bag operator with 2D indices, we need to set the offsets
  // explicitly here.
  if (indices.dim() == 2) {
    TORCH_CHECK(
        !offsets_in.has_value(),
        "embedding_bag_4bit operator: if the input is 2D, offsets have to be None, as the input is treated as a mini-batch of fixed length sequences.");

    offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(
        0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
  } else {
    TORCH_CHECK(
        offsets_in.has_value(),
        "embedding_bag_4bit operator expects offsets to be set for 1D indices.");
    offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
  }

  TORCH_CHECK(
      indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong,
      "Expect 32 or 64 bit indices, but found ",
      indices.scalar_type(),
      " instead.");
  TORCH_CHECK(
      offsets->scalar_type() == at::kInt || offsets->scalar_type() == at::kLong,
      "Expect 32 or 64 bit offsets, but found ",
      offsets->scalar_type(),
      " instead.");
  TORCH_CHECK(
      weight.is_contiguous() && indices.is_contiguous() &&
          offsets->is_contiguous(),
      "Expect weight, indices, and offsets to be contiguous.");

  if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
    return embedding_bag_4bit_impl<int, int>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset);
  } else if (
      indices.scalar_type() == at::kInt &&
      offsets->scalar_type() == at::kLong) {
    return embedding_bag_4bit_impl<int, int64_t>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset);
  } else if (
      indices.scalar_type() == at::kLong &&
      offsets->scalar_type() == at::kInt) {
    return embedding_bag_4bit_impl<int64_t, int>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset);
  }
  return embedding_bag_4bit_impl<int64_t, int64_t>(
      output,
      weight,
      indices,
      *offsets,
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset);
}

Tensor qembeddingbag_4bit_unpack(const Tensor& packed_weight) {
  int BIT_RATE = 4;
  const auto input_rows = packed_weight.size(0);
  const auto input_columns = packed_weight.size(1);
  const auto* input_data = packed_weight.const_data_ptr<uint8_t>();
  int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;

  // The last 4 bytes of each row store the fp16 scale and zero_point.
  // The remaining bytes of input_columns hold the packed values of the
  // original row.
  std::vector<int64_t> output_dimensions = {
      input_rows,
      static_cast<std::int64_t>(input_columns - 2 * sizeof(at::Half)) *
          NUM_ELEM_PER_BYTE};

  auto output = at::empty(
      output_dimensions,
      packed_weight.options().dtype(kFloat),
      packed_weight.suggest_memory_format());
  return output;
}
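
// Shape example for the 4-bit unpack above (illustrative): a packed weight of
// shape [N, D/2 + 4] uint8 (two 4-bit values per byte, plus an fp16 scale and
// an fp16 zero_point per row) maps to an output of shape [N, D] float32:
//   (input_columns - 2 * sizeof(at::Half)) * NUM_ELEM_PER_BYTE == (D/2) * 2 == D.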

TORCH_LIBRARY_IMPL(quantized, CUDA, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_unpack"),
      TORCH_FN(qembeddingbag_byte_unpack));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_rowwise_offsets"),
      TORCH_FN(embedding_bag_byte_rowwise_offsets));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_unpack"),
      TORCH_FN(qembeddingbag_4bit_unpack));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"),
      TORCH_FN(embedding_bag_4bit_rowwise_offsets));
}
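
// With these registrations, the quantized::embedding_bag_* ops listed above are
// routed to the CUDA implementations in this file whenever they are dispatched
// with CUDA tensors.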

} // namespace native
} // namespace at