#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/resize_native.h>
#endif

namespace at {
namespace native {

// BEGIN QUANTIZE HELPER FUNCTIONS
__device__ __forceinline__ float bfe(uint32_t val, uint32_t pos, uint32_t len) {
#ifdef USE_ROCM
  return static_cast<float>((val >> pos) & ((1u << len) - 1u));
#else
  uint32_t ret;
  // Get the bit field of [pos, pos+len) bits from val:
  // (val >> pos) & ((1u << len) - 1u)
  asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
  return __uint2float_rn(ret);
#endif
}

// FMA with constant scale/bias for all 4 floats in fa
__forceinline__ __device__ float4
fma4sb(const float4 fa, const float fscale, const float fbias) {
  float4 res;
#ifdef USE_ROCM
  res.x = fa.x * fscale + fbias;
  res.y = fa.y * fscale + fbias;
  res.z = fa.z * fscale + fbias;
  res.w = fa.w * fscale + fbias;
#else
  res.x = fmaf(fa.x, fscale, fbias);
  res.y = fmaf(fa.y, fscale, fbias);
  res.z = fmaf(fa.z, fscale, fbias);
  res.w = fmaf(fa.w, fscale, fbias);
#endif
  return res;
}

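// Extracts 4 consecutive bits_per_dim-wide fields from packedVals, starting at
// bit offset_bits, and applies the row's scale/bias (scale_bias.x / scale_bias.y)
// to produce 4 dequantized floats.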
template <uint8_t bits_per_dim>
__forceinline__ __device__ float4
dequantize_intx(uint32_t packedVals, float2 scale_bias, uint8_t offset_bits) {
  float4 res;

  res.x = bfe(packedVals, offset_bits + (0 * bits_per_dim), bits_per_dim);
  res.y = bfe(packedVals, offset_bits + (1 * bits_per_dim), bits_per_dim);
  res.z = bfe(packedVals, offset_bits + (2 * bits_per_dim), bits_per_dim);
  res.w = bfe(packedVals, offset_bits + (3 * bits_per_dim), bits_per_dim);

  return fma4sb(res, scale_bias.x, scale_bias.y);
}

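// Dequantizes every value packed in the 32-bit word and accumulates it into
// acc, scaled by sample_weight. For 8-bit rows this is one group of 4 values;
// for 4-bit rows it is two groups of 4 values.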
template <uint8_t bits_per_dim>
__forceinline__ __device__ void
accumulate_packed_intx(float4* acc, uint32_t packedVals, float2 scale_bias, float sample_weight) {
  constexpr uint8_t dims_per_byte = 8 / bits_per_dim;
  for (uint8_t i = 0; i < dims_per_byte; i++) {
    float4 res = dequantize_intx<bits_per_dim>(packedVals, scale_bias, 4 * bits_per_dim * i /* offset_bits */);
    // Accumulate in float32.
    acc[i].x += (res.x * sample_weight);
    acc[i].y += (res.y * sample_weight);
    acc[i].z += (res.z * sample_weight);
    acc[i].w += (res.w * sample_weight);
  }
}

// END QUANTIZE HELPER FUNCTIONS

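// Row layout: each row of `weight` stores the packed quantized values followed
// by the row's quantization parameters -- an fp32 scale and fp32 bias (8 bytes)
// for 8-bit rows, or an fp16 scale and fp16 bias (4 bytes) for 4-bit rows.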
// UN-OPTIMIZED kernel, doesn't even avoid warp divergence!
template <typename index_t, uint8_t bits_per_dim>
__global__ void embedding_bag_nbits_rowwise_offsets_kernel(
    const PackedTensorAccessor64<uint8_t, 2, RestrictPtrTraits> weight,
    const PackedTensorAccessor32<index_t, 1, RestrictPtrTraits> indices,
    const PackedTensorAccessor32<index_t, 1, RestrictPtrTraits> offsets,
    const bool /* pruned_weights */,
    const PackedTensorAccessor32<float, 1, RestrictPtrTraits> per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    const bool include_last_offset,
    PackedTensorAccessor32<float, 2, RestrictPtrTraits> output) {
  static_assert(
      bits_per_dim == 4 || bits_per_dim == 8,
      "embedding_bag_nbits_rowwise_offsets_kernel has only been tested for 4 and 8 bits per dim");
  constexpr uint8_t dims_per_byte = 8 / bits_per_dim;
  constexpr bool fp32_scale_bias = bits_per_dim == 8;

  int32_t B = output.size(0);
  int32_t D = output.size(1);
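  // With the current launch configuration (one block per output row and a
  // single thread per block), b_t equals blockIdx.x, so t is always 0 and b
  // indexes the output row (bag).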
  int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
  if (b_t >= B * D) {
    return;
  }
  int32_t t = b_t / B;
  int32_t b = b_t % B;

  const int32_t D_bytes = weight.size(1);

  bool use_per_sample = per_sample_weights_.size(0) > 0;

  int64_t indices_start = offsets[t * B + b];
  int64_t indices_end;
  if (include_last_offset) {
    indices_end = offsets[t * B + b + 1];
  } else {
    indices_end = (t * B + b + 1) < offsets.size(0) ? offsets[t * B + b + 1]
                                                    : indices.size(0);
  }

  int32_t L = indices_end - indices_start;
  const uint8_t* __restrict__ weights = &weight[0][0];

  if (L == 0) {
    for (int32_t d = 0; d < D; d += 4) {
      *(float4*)(&output[b][d]) = make_float4(0, 0, 0, 0);
    }
    return;
  }

  float4 accumulator[dims_per_byte];
  int32_t byte_offset = 0;
  for (int32_t d = 0; d < D; d += dims_per_byte * 4, byte_offset += 4) {
    for (int32_t i = 0; i < dims_per_byte; ++i) {
      accumulator[i] = make_float4(0, 0, 0, 0);
    }
    for (int32_t l = indices_start; l < indices_end; ++l) {
      int64_t idx = indices[l];
      float sample_weight = use_per_sample ? per_sample_weights_[l] : 1.0f;
      const uint8_t* __restrict__ row = &weights[idx * D_bytes];
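      // The quantization parameters live at the end of the row: fp32 scale
      // then fp32 bias for 8-bit rows, fp16 scale then fp16 bias for 4-bit
      // rows.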
      float2 scale_bias;
      if (fp32_scale_bias) {
        scale_bias = make_float2(
            reinterpret_cast<const float*>(&row[D_bytes - 8])[0],
            reinterpret_cast<const float*>(&row[D_bytes - 4])[0]);
      } else {
        scale_bias = make_float2(
            __half2float(reinterpret_cast<const __half*>(&row[D_bytes - 4])[0]),
            __half2float(reinterpret_cast<const __half*>(&row[D_bytes - 2])[0]));
      }

      uint32_t v0 = reinterpret_cast<const uint32_t*>(&row[byte_offset])[0];

      accumulate_packed_intx<bits_per_dim>(accumulator, v0, scale_bias, sample_weight);
    }

    for (int32_t i = 0; i < dims_per_byte; ++i) {
      *(float4*)(&output[b][d + (i * 4)]) = accumulator[i];
    }
  }
}

inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::ScalarType dtype) {
  return at::native::empty_cuda({0}, dtype, t.layout(), t.device(), false);
}

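// Allocates the float output for 8-bit packed rows: each packed row ends with
// an fp32 scale and an fp32 zero point, so the float output has
// 2 * sizeof(float) fewer columns than the packed byte rows.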
Tensor qembeddingbag_byte_unpack(const Tensor& packed_weight) {
  const auto packed_weight_sizes = packed_weight.sizes();
  const auto col_dim = packed_weight_sizes.size() - 1;
  const int32_t input_rows = c10::size_to_dim_(col_dim, packed_weight_sizes);
  const int32_t input_columns = packed_weight_sizes[col_dim];
  const int32_t output_columns = input_columns - 2 * sizeof(float);

  std::vector<int64_t> output_shape = packed_weight_sizes.vec();
  output_shape[col_dim] = output_columns;

  return at::empty(
      output_shape,
      packed_weight.options().dtype(kFloat),
      packed_weight.suggest_memory_format());
}

template <typename IndexType, typename OffsetType>
at::Tensor& embedding_bag_byte_impl(
    at::Tensor& output,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  TORCH_CHECK(weight.is_cuda());
  TORCH_CHECK(indices.is_cuda());
  TORCH_CHECK(offsets.is_cuda());
  TORCH_CHECK(indices.device() == weight.device());
  TORCH_CHECK(offsets.device() == weight.device());
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(per_sample_weights_.value().device() == weight.device());
  }
  if (compressed_indices_mapping.has_value()) {
    TORCH_CHECK(compressed_indices_mapping.value().device() == weight.device());
  }

  TORCH_CHECK(weight.dtype() == at::kByte);
  TORCH_CHECK(weight.dim() == 2);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(weight.get_device());

  const auto weight_sizes = weight.sizes();
  const int64_t N = weight_sizes[0];
  const int D = weight_sizes[1] - 8; // NB: -8 to account for scale and bias
  const int64_t M = offsets.sizes()[0];
  TORCH_CHECK(D % 4 == 0);
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat,
        "Per sample weights expected scalar type ", at::kFloat, " but got ",
        per_sample_weights_.value().scalar_type());
  }
  TORCH_CHECK(
      !compressed_indices_mapping.has_value(),
      "Compressed indices mapping not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");

  const auto maxThreads = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;

  int64_t output_size = include_last_offset ? M - 1 : M;

  at::Tensor sample_weights;
  if (per_sample_weights_.has_value()) {
    sample_weights = per_sample_weights_.value();
  } else {
    sample_weights = create_empty_from(output, kFloat);
  }

  const std::vector<int64_t> shape = {output_size, D};
  at::native::resize_(output, shape, std::nullopt);
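  // Launch one block per output row (bag) with a single thread each; this
  // matches the unoptimized kernel above, which processes a whole bag per
  // thread.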
  AT_DISPATCH_INDEX_TYPES(
      indices.scalar_type(), "embedding_bag_byte_rowwise_offsets_kernel", ([&] {
        embedding_bag_nbits_rowwise_offsets_kernel<index_t, 8><<<
            output_size,
            dim3(1, 1, 1),
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            weight.packed_accessor64<uint8_t, 2, RestrictPtrTraits>(),
            indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
            offsets.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
            false /* pruned_weights */,
            sample_weights.packed_accessor32<float, 1, RestrictPtrTraits>(),
            compressed_indices_mapping,
            include_last_offset,
            output.packed_accessor32<float, 2, RestrictPtrTraits>());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }));

  TORCH_CHECK(output.is_cuda());

  return output;
}

Tensor embedding_bag_byte_rowwise_offsets(
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool pruned_weights,
    const std::optional<Tensor>& per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  bool is_embedding_op = false;
  auto output = create_empty_from(weight, at::kFloat);

  c10::MaybeOwned<at::Tensor> offsets;
  TORCH_CHECK(
      indices.dim() == 1 || indices.dim() == 2,
      "qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
      indices.dim());
  // For the embedding_bag operator with 2D indices, we set the offsets
  // explicitly here.
  if (indices.dim() == 2 && !is_embedding_op) {
    TORCH_CHECK(
        !offsets_in.has_value(),
        "embedding_bag_byte operator: when indices are 2D, offsets must be None, as the input is treated as a mini-batch of fixed-length sequences.");

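    // For 2D indices of shape [B, L], every row is one bag, so the offsets
    // are simply 0, L, 2L, ... built with arange below.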
    offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(
        0, indices.numel(), indices.sizes()[1], indices.scalar_type()));

  } else {
    TORCH_CHECK(
        offsets_in.has_value(),
        "embedding_bag_byte expects offsets to be set for 1D indices.");
    offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
  }

  TORCH_CHECK(
      indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong,
      "Expect 32 or 64 bit indices, but found ",
      indices.scalar_type(),
      " instead.");
  TORCH_CHECK(
      offsets->scalar_type() == at::kInt || offsets->scalar_type() == at::kLong,
      "Expect 32 or 64 bit offsets, but found ",
      offsets->scalar_type(),
      " instead.");
  TORCH_CHECK(
      weight.is_contiguous() && indices.is_contiguous() &&
          offsets->is_contiguous(),
      "Expect weight, indices, and offsets to be contiguous.");

  // Use a helper function to support the different index/offset type
  // combinations without casting, which would add performance overhead.
  if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
    return embedding_bag_byte_impl<int, int>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  } else if (
      indices.scalar_type() == at::kInt &&
      offsets->scalar_type() == at::kLong) {
    return embedding_bag_byte_impl<int, int64_t>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  } else if (
      indices.scalar_type() == at::kLong &&
      offsets->scalar_type() == at::kInt) {
    return embedding_bag_byte_impl<int64_t, int>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  }

  // default case given the TORCH_CHECK above
  return embedding_bag_byte_impl<int64_t, int64_t>(
      output,
      weight,
      indices,
      *offsets,
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset,
      is_embedding_op);
}

template <typename IndexType, typename OffsetType>
at::Tensor& embedding_bag_4bit_impl(
    at::Tensor& output,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  TORCH_CHECK(weight.is_cuda());
  TORCH_CHECK(indices.is_cuda());
  TORCH_CHECK(offsets.is_cuda());
  TORCH_CHECK(indices.device() == weight.device());
  TORCH_CHECK(offsets.device() == weight.device());
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(per_sample_weights_.value().device() == weight.device());
  }
  if (compressed_indices_mapping.has_value()) {
    TORCH_CHECK(compressed_indices_mapping.value().device() == weight.device());
  }

  TORCH_CHECK(weight.dtype() == at::kByte);
  TORCH_CHECK(weight.dim() == 2);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(weight.get_device());

  const auto weight_sizes = weight.sizes();
  const int64_t N = weight_sizes[0];
  const int D = 2 * (weight_sizes[1] - 4); // NB: -4 to account for scale and bias @fp16
  const int64_t M = offsets.sizes()[0];
  TORCH_CHECK(D % 8 == 0);
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat,
        "Per sample weights expected scalar type ", at::kFloat, " but got ",
        per_sample_weights_.value().scalar_type());
  }
  TORCH_CHECK(
      !compressed_indices_mapping.has_value(),
      "Compressed indices mapping not yet implemented for embedding_bag_4bit_rowwise_offsets_cuda");

  const auto maxThreads = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;

  int64_t output_size = include_last_offset ? M - 1 : M;

  at::Tensor sample_weights;
  if (per_sample_weights_.has_value()) {
    sample_weights = per_sample_weights_.value();
  } else {
    sample_weights = create_empty_from(output, kFloat);
  }

  const std::vector<int64_t> shape = {output_size, D};
  at::native::resize_(output, shape, std::nullopt);
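  // Same launch strategy as the 8-bit path: one block per output row (bag)
  // with a single thread each.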
  AT_DISPATCH_INDEX_TYPES(
      indices.scalar_type(), "embedding_bag_4bit_rowwise_offsets_kernel", ([&] {
        embedding_bag_nbits_rowwise_offsets_kernel<index_t, 4><<<
            output_size,
            dim3(1, 1, 1),
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            weight.packed_accessor64<uint8_t, 2, RestrictPtrTraits>(),
            indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
            offsets.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
            false /* pruned_weights */,
            sample_weights.packed_accessor32<float, 1, RestrictPtrTraits>(),
            compressed_indices_mapping,
            include_last_offset,
            output.packed_accessor32<float, 2, RestrictPtrTraits>());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }));

  TORCH_CHECK(output.is_cuda());

  return output;
}

Tensor embedding_bag_4bit_rowwise_offsets(
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool pruned_weights,
    const std::optional<Tensor>& per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  auto output = create_empty_from(weight, at::kFloat);

  c10::MaybeOwned<at::Tensor> offsets;
  TORCH_CHECK(
      indices.dim() == 1 || indices.dim() == 2,
      "qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
      indices.dim());

  // For the embedding_bag operator with 2D indices, we need to set the offsets
  // explicitly here.
  if (indices.dim() == 2) {
    TORCH_CHECK(
        !offsets_in.has_value(),
        "embedding_bag_4bit operator: when indices are 2D, offsets must be None, as the input is treated as a mini-batch of fixed-length sequences.");

    offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(
        0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
  } else {
    TORCH_CHECK(
        offsets_in.has_value(),
        "embedding_bag_4bit operator expects offsets to be set for 1D indices.");
    offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
  }

  TORCH_CHECK(
      indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong,
      "Expect 32 or 64 bit indices, but found ",
      indices.scalar_type(),
      " instead.");
  TORCH_CHECK(
      offsets->scalar_type() == at::kInt || offsets->scalar_type() == at::kLong,
      "Expect 32 or 64 bit offsets, but found ",
      offsets->scalar_type(),
      " instead.");
  TORCH_CHECK(
      weight.is_contiguous() && indices.is_contiguous() &&
          offsets->is_contiguous(),
      "Expect weight, indices, and offsets to be contiguous.");

  if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
    return embedding_bag_4bit_impl<int, int>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset);
  } else if (
      indices.scalar_type() == at::kInt &&
      offsets->scalar_type() == at::kLong) {
    return embedding_bag_4bit_impl<int, int64_t>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset);
  } else if (
      indices.scalar_type() == at::kLong &&
      offsets->scalar_type() == at::kInt) {
    return embedding_bag_4bit_impl<int64_t, int>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset);
  }
  return embedding_bag_4bit_impl<int64_t, int64_t>(
      output,
      weight,
      indices,
      *offsets,
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset);
}

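// Allocates the float output for 4-bit packed rows: each remaining byte holds
// two values (NUM_ELEM_PER_BYTE = 2), and the trailing 4 bytes of every row
// hold the fp16 scale and zero_point.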
Tensor qembeddingbag_4bit_unpack(const Tensor& packed_weight) {
  int BIT_RATE = 4;
  const auto input_rows = packed_weight.size(0);
  const auto input_columns = packed_weight.size(1);
  const auto* input_data = packed_weight.const_data_ptr<uint8_t>();
  int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;

  // The last 4 bytes of each row hold the fp16 scale and zero_point.
  // The remaining input_columns hold the packed values of the original row.
  std::vector<int64_t> output_dimensions = {
      input_rows,
      static_cast<std::int64_t>(input_columns - 2 * sizeof(at::Half)) *
          NUM_ELEM_PER_BYTE};

  auto output = at::empty(
      output_dimensions,
      packed_weight.options().dtype(kFloat),
      packed_weight.suggest_memory_format());
  return output;
}

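// Register the CUDA implementations of the quantized embedding_bag operators.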
TORCH_LIBRARY_IMPL(quantized, CUDA, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_unpack"),
      TORCH_FN(qembeddingbag_byte_unpack));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_rowwise_offsets"),
      TORCH_FN(embedding_bag_byte_rowwise_offsets));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_unpack"),
      TORCH_FN(qembeddingbag_4bit_unpack));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"),
      TORCH_FN(embedding_bag_4bit_rowwise_offsets));
}

} // namespace native
} // namespace at