xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>

#include <ATen/Parallel.h>
#include <ATen/Utils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/custom_class.h>
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <c10/core/ScalarType.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/choose_qparams_optimized.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/resize_native.h>
#endif

#include <c10/util/irange.h>

// Standard headers for std::memcpy, std::min_element/std::max_element and
// std::isinf used below.
#include <algorithm>
#include <cmath>
#include <cstring>
#include <utility>

int register_embedding_params();

/*
 * Prepack function for embedding_bag weights.
 * This function expects a per-row quantized weight tensor
 * with a floating point scale and zero_point value.
 * The zero_point is set to (-Xmin / scale).
 * To prepack the weights we store the scale and bias (where bias is Xmin)
 * for each row along with the quantized weights.
 */
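/*
 * Illustrative sketch of the packed row layout produced below (derived from
 * the code in this function; sizes in bytes):
 *
 *   8-bit weights: | embedding_cols quantized bytes | fp32 scale | fp32 bias |
 *   4-bit weights: | ceil(embedding_cols / 2) bytes | fp16 scale | fp16 bias |
 *
 * where bias == Xmin == -zero_point * scale for each row.
 */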
c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(
    at::Tensor qweight) {
  static constexpr int64_t version = 1;
  TORCH_CHECK(
      qweight.dim() == 2,
      "quantized::embedding_bag_prepack weight tensor rank should be 2");
  TORCH_CHECK(
      qweight.scalar_type() == c10::kQUInt8 ||
          qweight.scalar_type() == c10::kQUInt4x2,
      "qembedding_bag_prepack currently only supports quint8 and quint4x2 weights");

  at::Tensor weight_contig =
      qweight.contiguous(qweight.suggest_memory_format());

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int bit_width, scale_bias_bytes;
  uint8_t* weight_data = static_cast<uint8_t*>(weight_contig.data_ptr());
  if (qweight.scalar_type() == c10::kQUInt8) {
    bit_width = 8;
    scale_bias_bytes =
        sizeof(float) * 2; // extra 8 bytes to store FP scale and bias per row.
  } else {
    bit_width = 4;
    scale_bias_bytes = sizeof(at::Half) *
        2; // extra 4 bytes to store at::Half scale and bias per row.
  }
  const auto num_elem_per_byte = 8 / bit_width;

  int64_t embedding_rows = qweight.size(0);
  int64_t embedding_cols = qweight.size(1);
  const auto qtype = qweight.qscheme();
  TORCH_CHECK(
      qtype == c10::kPerChannelAffineFloatQParams,
      "Expect embedding_bag weights to be quantized using kPerChannelAffineFloatQParams");
  std::vector<float> weight_bias(embedding_rows);

  at::Tensor channel_scales = qweight.q_per_channel_scales();
  at::Tensor channel_zero_points = qweight.q_per_channel_zero_points();
  std::vector<float> weight_scales(
      channel_scales.data_ptr<float>(),
      channel_scales.data_ptr<float>() + embedding_rows);
  std::vector<float> weight_zero_points(
      channel_zero_points.data_ptr<float>(),
      channel_zero_points.data_ptr<float>() + embedding_rows);

  for (const auto i : c10::irange(embedding_rows)) {
    weight_bias[i] = weight_zero_points[i] * weight_scales[i] * -1;
  }

  std::vector<int64_t> output_shape = {
      embedding_rows,
      static_cast<std::int64_t>(
          (embedding_cols + num_elem_per_byte - 1) / num_elem_per_byte +
          scale_bias_bytes)}; // extra bytes to store scale and bias per row.
  size_t output_columns = output_shape[1];

  // Allocate output packed weights.
  at::Tensor output = at::empty(
      output_shape,
      weight_contig.options().dtype(at::kByte),
      weight_contig.suggest_memory_format());
  auto* output_data = output.data_ptr<uint8_t>();

  if (bit_width == 8) {
    at::parallel_for(
        0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
          for (const auto row : c10::irange(start_idx, end_idx)) {
            const uint8_t* input_row = weight_data + row * embedding_cols;
            std::uint8_t* output_row = output_data + row * output_columns;
            auto output_row_scale_bias = output_row + embedding_cols;
            // don't use float* to avoid unaligned address access
            std::memcpy(
                output_row_scale_bias, &(weight_scales[row]), sizeof(float));
            std::memcpy(
                output_row_scale_bias + sizeof(float),
                &(weight_bias[row]),
                sizeof(float));
            for (const auto col : c10::irange(embedding_cols)) {
              output_row[col] = input_row[col];
            }
          }
        });
  } else {
    // Re-calculate the number of embedding_cols, to account for values packed
    // in a byte.
    embedding_cols =
        (embedding_cols + num_elem_per_byte - 1) / num_elem_per_byte;
    at::parallel_for(
        0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
          for (const auto row : c10::irange(start_idx, end_idx)) {
            const uint8_t* input_row = weight_data + row * embedding_cols;
            std::uint8_t* output_row = output_data + row * output_columns;
            auto output_row_scale_bias = output_row + embedding_cols;
            at::Half weight_scale = weight_scales[row];
            at::Half weight_bias_half = weight_bias[row];
            // don't use at::Half* to avoid unaligned address access
            std::memcpy(output_row_scale_bias, &weight_scale, sizeof(at::Half));
            std::memcpy(
                output_row_scale_bias + sizeof(at::Half),
                &weight_bias_half,
                sizeof(at::Half));

            for (const auto col : c10::irange(embedding_cols)) {
              // The weight values have already been packed, so here we just
              // store them in the output tensor.
              output_row[col] = input_row[col];
            }
          }
        });
  }

  auto packed_ptr = c10::make_intrusive<PackedEmbeddingBagWeight>(
      output,
      std::move(weight_scales),
      std::move(weight_zero_points),
      bit_width,
      qtype,
      version);

  return packed_ptr;
}

namespace at {
namespace native {

// Note - This is a temporary pack function for embedding_bag which quantizes
// and packs the float weight tensor. It will be replaced by a
// quantize-and-pack function once FP scale and FP zero_point are supported.
//
// Python example examining a packed 8bit zero_point and scale:
//
// >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]],
// dtype=np.float32))
// >> x_packed = torch.ops.quantized.embedding_bag_byte_prepack(x)
//
// # Pull out and examine packed scales, zero_points and values
// >> zero_points = x_packed[:,:,-4:].numpy()
// >> scales = x_packed[:,:,-8:-4].numpy()
// >> values = x_packed[:,:,:-8].numpy()
//
// >> zero_points
// array([[[  0,   0,  32,  65],
//        [  0,   0, 240,  65]],
//
//       [[  0,   0,  72,  66],
//        [  0,   0, 140,  66]]], dtype=uint8)
//
// >> scales
// array([[[161, 160,  32,  61],
//        [161, 160,  32,  61]],
//
//       [[161, 160,  32,  61],
//        [161, 160,  32,  61]]], dtype=uint8)
// >> values
// array([[[  0, 255],
//        [  0, 255]],
//
//       [[  0, 255],
//        [  0, 255]]], dtype=uint8)
//
// # Convert 4 byte packed scales and zero_points to float
// # and apply against values in order to recover unquantized values.
// def bytes2float(arr):
//    packed_hex = bytearray(arr)
//    return struct.unpack('f', packed_hex)
//
// >> float_zero_points = np.apply_along_axis(bytes2float, 2, zero_points)
// >> float_zero_points
// array([[[10.],
//         [30.]],
//
//        [[50.],
//         [70.]]])
// >> float_scales = np.apply_along_axis(bytes2float, 2, scales)
// >> float_scales
// array([[[0.03921569],
//        [0.03921569]],
//
//       [[0.03921569],
//        [0.03921569]]])
// >> values *  float_scales + float_zero_points
// array([[[10.        , 20.00000035],
//         [30.        , 40.00000035]],
//
//        [[50.        , 60.00000035],
//         [70.        , 80.00000035]]])
Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
  // The "last" dimension of an N-dimensional batch of embedding bags is the
  // quantization channel. E.g. for a 2D embedding bag, this has
  // [ row, col ] dimensions; for a batch of embedding bags, dimensions might
  // be [ batch, row, col ].
  //
  // Python Batched Embedding Example:
  // weights = torch.from_numpy((np.random.random_sample((
  //          2, 10, 3)).squeeze() + 1).astype(np.float32))
  // assert(weights.size() == torch.Size([2, 10, 3]))
  // # NOTE: 8 bytes (columns) are added due to fp32 zero_point and scales
  // packed_weights = torch.ops.quantized.embedding_bag_byte_prepack(weights)
  // assert(packed_weights.size() == torch.Size([2, 10, 11]))

  TORCH_CHECK(
      weight.scalar_type() == at::ScalarType::Float ||
          weight.scalar_type() == at::ScalarType::Half,
      "'embedding_bag_byte_prepack' only support float32 or float16.");

  const auto weight_sizes = weight.sizes();
  const auto cols_dim = weight_sizes.size() - 1;
  const int64_t embedding_rows = c10::size_to_dim_(cols_dim, weight_sizes);
  const int32_t embedding_cols = weight_sizes[cols_dim];
  // Add 8 bytes per row (as extra columns) to store the FP32 scale and
  // zero_point.
  const int32_t output_columns = embedding_cols + 2 * sizeof(float);
  const auto weight_contig =
      weight.expect_contiguous(weight.suggest_memory_format());

  // Adjust output dimensions to account for FP32 scale and zero_points.
  std::vector<int64_t> output_shape = weight_sizes.vec();
  output_shape[cols_dim] = output_columns;
  at::native::resize_(output, output_shape, std::nullopt);
  auto* output_data = output.data_ptr<uint8_t>();

#ifdef USE_FBGEMM
  if (weight_contig->scalar_type() == at::ScalarType::Half) {
    const auto weight_data =
        static_cast<fbgemm::float16*>(weight_contig->data_ptr());
    at::parallel_for(
        0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
          fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<
              fbgemm::float16>(
              weight_data + start_idx * embedding_cols,
              end_idx - start_idx,
              embedding_cols,
              output_data + start_idx * output_columns);
        });
  } else {
    const auto weight_data = weight_contig->data_ptr<float>();
    at::parallel_for(
        0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
          fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
              weight_data + start_idx * embedding_cols,
              end_idx - start_idx,
              embedding_cols,
              output_data + start_idx * output_columns);
        });
  }

#else
  const Tensor& float_weight =
      weight_contig->scalar_type() == at::ScalarType::Half
      ? weight_contig->to(at::ScalarType::Float)
      : *weight_contig;
  const auto weight_data = float_weight.data_ptr<float>();
  constexpr float kEpsilon = 1e-8f;
  for (auto row : c10::irange(embedding_rows)) {
    const float* input_row = weight_data + row * embedding_cols;
    std::uint8_t* output_row = output_data + row * output_columns;
    float* output_row_scale_zp =
        reinterpret_cast<float*>(output_row + embedding_cols);

    float minimum_element =
        *std::min_element(input_row, input_row + embedding_cols);
    float maximum_element =
        *std::max_element(input_row, input_row + embedding_cols);
    float range = maximum_element - minimum_element;

    output_row_scale_zp[0] = range / 255.0f;
    output_row_scale_zp[1] = minimum_element;
    const auto inverse_scale = 255.0f / (range + kEpsilon);
    for (auto col : c10::irange(embedding_cols)) {
      output_row[col] =
          lrintf((input_row[col] - minimum_element) * inverse_scale);
    } // embedding_cols
  } // embedding_rows
#endif // USE_FBGEMM

  return output;
}

Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
  const auto weight_contig =
      weight.expect_contiguous(weight.suggest_memory_format());
  Tensor output = at::detail::empty_cpu(
      {0},
      at::kByte,
      weight_contig->layout(),
      weight_contig->device(),
      std::nullopt,
      std::nullopt);
  qembeddingbag_byte_prepack_out(output, weight);
  return output;
}
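
// Illustrative usage sketch (shapes follow the packing logic above: 8 extra
// bytes per row hold the fp32 scale and zero_point; tensor sizes here are
// arbitrary example values):
//
// >> weights = torch.rand(10, 3, dtype=torch.float32)
// >> packed = torch.ops.quantized.embedding_bag_byte_prepack(weights)
// >> packed.shape
// torch.Size([10, 11])  # 3 data bytes + 4-byte scale + 4-byte zero_point per row
// >> packed.dtype
// torch.uint8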

Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
  const auto weight_contig =
      weight.expect_contiguous(weight.suggest_memory_format());
  TORCH_CHECK(
      weight.scalar_type() == at::ScalarType::Float ||
          weight.scalar_type() == at::ScalarType::Half,
      "'embedding_bag_byte_prepack' only support float32 or float16.");
  const auto weight_sizes = weight.sizes();
  const auto cols_dim = weight_sizes.size() - 1;
  const int32_t embedding_cols = weight_sizes[cols_dim];
  // Add 8 bytes per row (as extra columns) to store the FP32 scale and
  // zero_point.
  const int32_t output_columns = embedding_cols + 2 * sizeof(float);

  // Adjust output dimensions to account for FP32 scale and zero_points.
  std::vector<int64_t> output_shape = weight_sizes.vec();
  output_shape[cols_dim] = output_columns;
  at::SymDimVector output_shape_vec(output_shape);

  return at::empty_symint(
      output_shape_vec,
      weight.options().dtype(weight.scalar_type()),
      weight.suggest_memory_format());
}

namespace {

// TODO: Extend support to N-D batched embeddings, similar to
// qembeddingbag_byte_prepack
Tensor _qembeddingbag_nbit_prepack_helper(
    const Tensor& weight,
    int bit_width,
    const bool optimized_qparams,
    const int64_t nbins,
    const double ratio) {
  TORCH_CHECK(
      weight.scalar_type() == at::ScalarType::Float ||
          weight.scalar_type() == at::ScalarType::Half,
      "'qembeddingbag_nbit_prepack' only support float32 or float16.");

  int64_t embedding_rows = weight.size(0);
  int64_t embedding_cols = weight.size(1);

  Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());

  TORCH_CHECK(
      bit_width == 4 || bit_width == 2,
      "bit_width must be either 2 or 4 to use 'qembeddingbag_nbit_prepack'."
      "For 8bit, consider using 'embedding_bag_byte_prepack'.");

  int NUM_ELEM_PER_BYTE = 8 / bit_width;
  TORCH_CHECK(
      weight_contig.size(weight.dim() - 1) % NUM_ELEM_PER_BYTE == 0,
      "qembeddingbag_",
      std::to_string(bit_width),
      "bit_prepack only works for the number of columns a multiple of ",
      std::to_string(NUM_ELEM_PER_BYTE));

  // The "fused" representation stores the scale and bias with the
  // row-wise quantized data in one tensor.
  // Since we represent the scale and bias in 16-bit float, we'll use the
  // last 4 bytes of each row for scale (2 bytes) and bias (2 bytes).
  // | ... quantized data ... | scale | bias |
  // |    number_of_columns   |  2B   |  2B  |
  std::vector<int64_t> output_shape = {
      embedding_rows,
      static_cast<std::int64_t>(
          (embedding_cols + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE +
          2 * sizeof(at::Half))};
  auto output = at::empty(
      output_shape,
      weight_contig.options().dtype(at::kByte),
      weight_contig.suggest_memory_format());
  auto* output_data = output.data_ptr<uint8_t>();

#ifdef USE_FBGEMM
  if (!optimized_qparams) {
    if (weight_contig.scalar_type() == at::ScalarType::Half) {
      const auto weight_data =
          static_cast<fbgemm::float16*>(weight_contig.data_ptr());
      at::parallel_for(
          0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
            fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<
                fbgemm::float16>(
                bit_width,
                weight_data + start_idx * embedding_cols,
                end_idx - start_idx,
                embedding_cols,
                output_data + start_idx * output_shape[1]);
          });
    } else {
      const auto weight_data = weight_contig.data_ptr<float>();
      at::parallel_for(
          0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
            fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
                bit_width,
                weight_data + start_idx * embedding_cols,
                end_idx - start_idx,
                embedding_cols,
                output_data + start_idx * output_shape[1]);
          });
    }
  } else {
#endif // USE_FBGEMM
    const auto output_columns = output.size(output.dim() - 1);
    const auto float_weight =
        weight_contig.scalar_type() == at::ScalarType::Half
        ? weight_contig.to(at::ScalarType::Float)
        : std::move(weight_contig);
    const auto weight_data = float_weight.data_ptr<float>();
    for (const auto row : c10::irange(embedding_rows)) {
      const float* input_row = weight_data + row * embedding_cols;
      std::uint8_t* output_row = output_data + row * output_columns;

      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      float Xmin, Xmax;
      if (optimized_qparams) {
        auto [xmax_tensor, xmin_tensor] = at::choose_qparams_optimized(
            float_weight[row], embedding_cols, nbins, ratio, bit_width);
        TORCH_CHECK(
            xmax_tensor.numel() == 1 && xmin_tensor.numel() == 1,
            "Expected choose_qparams_optimized to return min/max tensors of size 1");
        Xmax = xmax_tensor.item<float>();
        Xmin = xmin_tensor.item<float>();
      } else {
        Xmin = *std::min_element(input_row, input_row + embedding_cols);
        Xmax = *std::max_element(input_row, input_row + embedding_cols);
      }
      // Round Xmin to fp16 precision, since it is stored as an at::Half below.
      Xmin = static_cast<at::Half>(Xmin);
      float range = Xmax - Xmin;
      // Set scale to 1.0f for the corner case of Xmax == Xmin.
      // Any non-zero scale would work because during quantization
      // (X - Xmin) / scale will be 0 for all X unless scale is 0.
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      at::Half scale = range == 0 ? 1.0f : range / ((1 << bit_width) - 1);
      float inverse_scale = scale == 0 ? 1.0f : 1.0f / scale;
      if (scale == 0 || std::isinf(inverse_scale)) {
        // Corner case handling when Xmax == Xmin:
        // any scale would work because X - Xmin will be 0 for all X.
        scale = 1.0f;
        inverse_scale = 1.0f;
      }
      // Update the scale and zero_point of each row.
      at::Half* output_row_scale_zp = reinterpret_cast<at::Half*>(
          output_row +
          (embedding_cols + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);

      output_row_scale_zp[0] = scale;
      output_row_scale_zp[1] = Xmin;

      // Pack the weight values.
      for (const auto col : c10::irange(embedding_cols)) {
        float X = input_row[col];
        std::uint8_t quantized = std::max(
            0,
            std::min<int>(
                lrintf((X - Xmin) * inverse_scale), (1 << bit_width) - 1));
        // NUM_ELEM_PER_BYTE values are packed per byte. E.g. for 4-bit
        // quantization, index 0 goes in the lower 4 bits of a byte and
        // index 1 in the upper 4 bits.
        if (col % NUM_ELEM_PER_BYTE == 0) {
          output_row[col / NUM_ELEM_PER_BYTE] = quantized;
        } else {
          output_row[col / NUM_ELEM_PER_BYTE] |=
              (quantized << ((col % NUM_ELEM_PER_BYTE) * bit_width));
        }
      } // embedding_cols
    } // embedding_rows
#ifdef USE_FBGEMM
  }
#endif // USE_FBGEMM

  return output;
}
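
// Worked example for the fallback 4-bit quantization path above (a sketch;
// the input row is an arbitrary example and the scale is rounded to fp16):
//
//   input row    = [0.0, 0.3, 0.9, 1.5]
//   Xmin = 0.0, range = 1.5, scale = range / 15 ~= 0.1 (stored as fp16)
//   quantized    = [0, 3, 9, 15]
//   packed data  = [0x30, 0xF9]  (col 0 in the low nibble, col 1 in the high nibble)
//   output row   = [0x30, 0xF9, <fp16 scale>, <fp16 bias>]  -> 2 + 4 = 6 bytes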

// Applies 4-bit row-wise quantization by determining the range
// (maximum - minimum) and bias (minimum value) of each row in the input
// matrix, and then scaling each element to a 4-bit number between 0 and
// 15.
// To later de-quantize values, the scale (range / 15) and zero_point
// are stored alongside the data. More precisely, each row first has quantized
// values, and then 2-byte fp16 scale and 2-byte zero_offset.
Tensor qembeddingbag_4bit_prepack(
    const Tensor& weight,
    const bool optimized_qparams,
    const int64_t nbins,
    const double ratio) {
  return _qembeddingbag_nbit_prepack_helper(
      weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio);
}
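
// Illustrative usage sketch (tensor sizes are arbitrary example values; nbins
// and ratio are only consulted when optimized_qparams is true):
//
// >> weights = torch.rand(10, 16, dtype=torch.float32)
// >> packed = torch.ops.quantized.embedding_bag_4bit_prepack(weights, False, 200, 0.16)
// >> packed.shape
// torch.Size([10, 12])  # 16 / 2 data bytes + 2-byte fp16 scale + 2-byte fp16 bias per row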

// Applies 2-bit row-wise quantization by determining the range
// (maximum - minimum) and bias (minimum value) of each row in the input
// matrix, and then scaling each element to a 2-bit number between 0 and
// 3.
// To later de-quantize values, the scale (range / 3) and zero_point
// are stored alongside the data. More precisely, each row first has quantized
// values, and then 2-byte fp16 scale and 2-byte zero_offset.
// TODO: Add a 2-bit embedding lookup operator.
Tensor qembeddingbag_2bit_prepack(
    const Tensor& weight,
    const bool optimized_qparams,
    const int64_t nbins,
    const double ratio) {
  return _qembeddingbag_nbit_prepack_helper(
      weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio);
}
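
// Illustrative usage sketch, analogous to the 4-bit case (four 2-bit values
// are packed per byte; tensor sizes are arbitrary example values):
//
// >> weights = torch.rand(10, 16, dtype=torch.float32)
// >> packed = torch.ops.quantized.embedding_bag_2bit_prepack(weights, False, 200, 0.16)
// >> packed.shape
// torch.Size([10, 8])  # 16 / 4 data bytes + 2-byte fp16 scale + 2-byte fp16 bias per row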

class QEmbeddingPackWeights final {
 public:
  static c10::intrusive_ptr<EmbeddingPackedParamsBase> run(at::Tensor weight) {
    return PackedEmbeddingBagWeight::prepack(std::move(weight));
  }
};

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"),
      TORCH_FN(qembeddingbag_byte_prepack));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"),
      TORCH_FN(qembeddingbag_4bit_prepack));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"),
      TORCH_FN(qembeddingbag_2bit_prepack));
}

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_prepack"),
      TORCH_FN(QEmbeddingPackWeights::run));
}

TORCH_LIBRARY_IMPL(quantized, Meta, m) {
  m.impl(
      "quantized::embedding_bag_byte_prepack", qembeddingbag_byte_prepack_meta);
}

} // namespace
} // namespace native
} // namespace at