#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>

#include <ATen/Parallel.h>
#include <ATen/Utils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/custom_class.h>
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <c10/core/ScalarType.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/choose_qparams_optimized.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/resize_native.h>
#endif

#include <c10/util/irange.h>

#include <utility>

int register_embedding_params();

/*
 * Prepack function for embedding_bag weights.
 * This function expects a per-row quantized weight tensor
 * with a floating point scale and zero_point value.
 * The zero point is set to (-Xmin / scale).
 * To prepack the weights we store the scale and bias (where bias is Xmin)
 * for each row along with the quantized weights.
 */
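/*
 * Packed row layout for the quint8 case (a minimal sketch of how a packed
 * row can be read back; the offsets follow the copy order in prepack()
 * below, and the Python names are illustrative only):
 *
 *   | quantized bytes (embedding_cols) | fp32 scale (4B) | fp32 bias (4B) |
 *
 *   # packed_row: the raw bytes (e.g. Python bytes) of one packed row
 *   scale, bias = struct.unpack('ff', packed_row[-8:])
 *   dequantized = [q * scale + bias for q in packed_row[:-8]]
 *
 * For the quint4x2 case the trailing scale/bias are stored as two 2-byte
 * at::Half values instead.
 */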
c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(
    at::Tensor qweight) {
  static constexpr int64_t version = 1;
  TORCH_CHECK(
      qweight.dim() == 2,
      "quantized::embedding_bag_prepack weight tensor rank should be 2");
  TORCH_CHECK(
      qweight.scalar_type() == c10::kQUInt8 ||
          qweight.scalar_type() == c10::kQUInt4x2,
      "qembedding_bag_prepack currently only supports quint8 and quint4x2 weights");

  at::Tensor weight_contig =
      qweight.contiguous(qweight.suggest_memory_format());

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int bit_width, scale_bias_bytes;
  uint8_t* weight_data = static_cast<uint8_t*>(weight_contig.data_ptr());
  if (qweight.scalar_type() == c10::kQUInt8) {
    bit_width = 8;
    scale_bias_bytes =
        sizeof(float) * 2; // extra 8 bytes to store FP scale and bias per row.
  } else {
    bit_width = 4;
    scale_bias_bytes = sizeof(at::Half) *
        2; // extra 4 bytes to store at::Half scale and bias per row.
  }
  const auto num_elem_per_byte = 8 / bit_width;

  int64_t embedding_rows = qweight.size(0);
  int64_t embedding_cols = qweight.size(1);
  const auto qtype = qweight.qscheme();
  TORCH_CHECK(
      qtype == c10::kPerChannelAffineFloatQParams,
      "Expect embedding_bag weights to be quantized using kPerChannelAffineFloatQParams");
  std::vector<float> weight_bias(embedding_rows);

  at::Tensor channel_scales = qweight.q_per_channel_scales();
  at::Tensor channel_zero_points = qweight.q_per_channel_zero_points();
  std::vector<float> weight_scales(
      channel_scales.data_ptr<float>(),
      channel_scales.data_ptr<float>() + embedding_rows);
  std::vector<float> weight_zero_points(
      channel_zero_points.data_ptr<float>(),
      channel_zero_points.data_ptr<float>() + embedding_rows);

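  // Per the note above, zero_point == -Xmin / scale, so -zero_point * scale
  // recovers Xmin, which is stored as the per-row bias.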
  for (const auto i : c10::irange(embedding_rows)) {
    weight_bias[i] = weight_zero_points[i] * weight_scales[i] * -1;
  }

  std::vector<int64_t> output_shape = {
      embedding_rows,
      static_cast<std::int64_t>(
          (embedding_cols + num_elem_per_byte - 1) / num_elem_per_byte +
          scale_bias_bytes)}; // extra bytes to store scale and bias per row.
  size_t output_columns = output_shape[1];

  // Allocate output packed weights.
  at::Tensor output = at::empty(
      output_shape,
      weight_contig.options().dtype(at::kByte),
      weight_contig.suggest_memory_format());
  auto* output_data = output.data_ptr<uint8_t>();

  if (bit_width == 8) {
    at::parallel_for(
        0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
          for (const auto row : c10::irange(start_idx, end_idx)) {
            const uint8_t* input_row = weight_data + row * embedding_cols;
            std::uint8_t* output_row = output_data + row * output_columns;
            auto output_row_scale_bias = output_row + embedding_cols;
            // don't use float* to avoid unaligned address access
            std::memcpy(
                output_row_scale_bias, &(weight_scales[row]), sizeof(float));
            std::memcpy(
                output_row_scale_bias + sizeof(float),
                &(weight_bias[row]),
                sizeof(float));
            for (const auto col : c10::irange(embedding_cols)) {
              output_row[col] = input_row[col];
            }
          }
        });
  } else {
    // Re-calculate the number of embedding_cols to account for values packed
    // in a byte.
    embedding_cols =
        (embedding_cols + num_elem_per_byte - 1) / num_elem_per_byte;
    at::parallel_for(
        0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
          for (const auto row : c10::irange(start_idx, end_idx)) {
            const uint8_t* input_row = weight_data + row * embedding_cols;
            std::uint8_t* output_row = output_data + row * output_columns;
            auto output_row_scale_bias = output_row + embedding_cols;
            at::Half weight_scale = weight_scales[row];
            at::Half weight_bias_half = weight_bias[row];
            // don't use at::Half* to avoid unaligned address access
            std::memcpy(output_row_scale_bias, &weight_scale, sizeof(at::Half));
            std::memcpy(
                output_row_scale_bias + sizeof(at::Half),
                &weight_bias_half,
                sizeof(at::Half));

            for (const auto col : c10::irange(embedding_cols)) {
              // The weight values have already been packed, so here we just
              // store them in the output tensor.
              output_row[col] = input_row[col];
            }
          }
        });
  }

  auto packed_ptr = c10::make_intrusive<PackedEmbeddingBagWeight>(
      output,
      std::move(weight_scales),
      std::move(weight_zero_points),
      bit_width,
      qtype,
      version);

  return packed_ptr;
}

namespace at {
namespace native {

// Note - This is a temporary pack function for embedding bag which quantizes
// and packs the float weight tensor. In the next step it will be replaced by a
// quantize and pack function once we support FP scale and FP zero_point.
//
// Python example examining a packed 8bit zero_point and scale:
//
// >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]],
// dtype=np.float32))
// >> x_packed = torch.ops.quantized.embedding_bag_byte_prepack(x)
//
// # Pull out and examine packed scales, zero_points and values
// >> zero_points = x_packed[:,:,-4:].numpy()
// >> scales = x_packed[:,:,-8:-4].numpy()
// >> values = x_packed[:,:,:-8].numpy()
//
// >> zero_points
// array([[[  0,   0,  32,  65],
//         [  0,   0, 240,  65]],
//
//        [[  0,   0,  72,  66],
//         [  0,   0, 140,  66]]], dtype=uint8)
//
// >> scales
// array([[[161, 160,  32,  61],
//         [161, 160,  32,  61]],
//
//        [[161, 160,  32,  61],
//         [161, 160,  32,  61]]], dtype=uint8)
// >> values
// array([[[  0, 255],
//         [  0, 255]],
//
//        [[  0, 255],
//         [  0, 255]]], dtype=uint8)
//
// # Convert 4 byte packed scales and zero_points to float
// # and apply against values in order to recover unquantized values.
// def bytes2float(arr):
//   packed_hex = bytearray(arr)
//   return struct.unpack('f', packed_hex)
//
// >> float_zero_points = np.apply_along_axis(bytes2float, 2, zero_points)
// >> float_zero_points
// array([[[10.],
//         [30.]],
//
//        [[50.],
//         [70.]]])
// >> float_scales = np.apply_along_axis(bytes2float, 2, scales)
// >> float_scales
// array([[[0.03921569],
//         [0.03921569]],
//
//        [[0.03921569],
//         [0.03921569]]])
// >> values * float_scales + float_zero_points
// array([[[10.        , 20.00000035],
//         [30.        , 40.00000035]],
//
//        [[50.        , 60.00000035],
//         [70.        , 80.00000035]]])
Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
  // The "last" dimension of an N-dimensional batch of embedding bags is the
  // quantization channel. E.g. a 2D embedding bag has [ row, col ] dimensions;
  // a batch of embedding bags might have [ batch, row, col ] dimensions.
  //
  // Python Batched Embedding Example:
  // weights = torch.from_numpy((np.random.random_sample((
  //          2, 10, 3)).squeeze() + 1).astype(np.float32))
  // assert(weights.size() == torch.Size([2, 10, 3]))
  // # NOTE: 8 bytes (columns) are added due to fp32 zero_point and scales
  // packed_weights = torch.ops.quantized.embedding_bag_byte_prepack(weights)
  // assert(packed_weights.size() == torch.Size([2, 10, 11]))

  TORCH_CHECK(
      weight.scalar_type() == at::ScalarType::Float ||
          weight.scalar_type() == at::ScalarType::Half,
      "'embedding_bag_byte_prepack' only supports float32 or float16.");

  const auto weight_sizes = weight.sizes();
  const auto cols_dim = weight_sizes.size() - 1;
  const int64_t embedding_rows = c10::size_to_dim_(cols_dim, weight_sizes);
  const int32_t embedding_cols = weight_sizes[cols_dim];
  // Add 8 bytes to each row to store the FP32 scale and zero_point.
  const int32_t output_columns = embedding_cols + 2 * sizeof(float);
  const auto weight_contig =
      weight.expect_contiguous(weight.suggest_memory_format());

  // Adjust output dimensions to account for FP32 scale and zero_points.
  std::vector<int64_t> output_shape = weight_sizes.vec();
  output_shape[cols_dim] = output_columns;
  at::native::resize_(output, output_shape, std::nullopt);
  auto* output_data = output.data_ptr<uint8_t>();

#ifdef USE_FBGEMM
  if (weight_contig->scalar_type() == at::ScalarType::Half) {
    const auto weight_data =
        static_cast<fbgemm::float16*>(weight_contig->data_ptr());
    at::parallel_for(
        0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
          fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<
              fbgemm::float16>(
              weight_data + start_idx * embedding_cols,
              end_idx - start_idx,
              embedding_cols,
              output_data + start_idx * output_columns);
        });
  } else {
    const auto weight_data = weight_contig->data_ptr<float>();
    at::parallel_for(
        0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
          fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
              weight_data + start_idx * embedding_cols,
              end_idx - start_idx,
              embedding_cols,
              output_data + start_idx * output_columns);
        });
  }

#else
  const Tensor& float_weight =
      weight_contig->scalar_type() == at::ScalarType::Half
      ? weight_contig->to(at::ScalarType::Float)
      : *weight_contig;
  const auto weight_data = float_weight.data_ptr<float>();
  constexpr float kEpsilon = 1e-8f;
  for (auto row : c10::irange(embedding_rows)) {
    const float* input_row = weight_data + row * embedding_cols;
    std::uint8_t* output_row = output_data + row * output_columns;
    float* output_row_scale_zp =
        reinterpret_cast<float*>(output_row + embedding_cols);

    float minimum_element =
        *std::min_element(input_row, input_row + embedding_cols);
    float maximum_element =
        *std::max_element(input_row, input_row + embedding_cols);
    float range = maximum_element - minimum_element;

    output_row_scale_zp[0] = range / 255.0f;
    output_row_scale_zp[1] = minimum_element;
    const auto inverse_scale = 255.0f / (range + kEpsilon);
    for (auto col : c10::irange(embedding_cols)) {
      output_row[col] =
          lrintf((input_row[col] - minimum_element) * inverse_scale);
    } // embedding_cols
  } // embedding_rows
#endif // USE_FBGEMM

  return output;
}

Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
  const auto weight_contig =
      weight.expect_contiguous(weight.suggest_memory_format());
  Tensor output = at::detail::empty_cpu(
      {0},
      at::kByte,
      weight_contig->layout(),
      weight_contig->device(),
      std::nullopt,
      std::nullopt);
  qembeddingbag_byte_prepack_out(output, weight);
  return output;
}

Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
  const auto weight_contig =
      weight.expect_contiguous(weight.suggest_memory_format());
  TORCH_CHECK(
      weight.scalar_type() == at::ScalarType::Float ||
          weight.scalar_type() == at::ScalarType::Half,
      "'embedding_bag_byte_prepack' only supports float32 or float16.");
  const auto weight_sizes = weight.sizes();
  const auto cols_dim = weight_sizes.size() - 1;
  const int32_t embedding_cols = weight_sizes[cols_dim];
  // Add 8 bytes to each row to store the FP32 scale and zero_point.
  const int32_t output_columns = embedding_cols + 2 * sizeof(float);

  // Adjust output dimensions to account for FP32 scale and zero_points.
  std::vector<int64_t> output_shape = weight_sizes.vec();
  output_shape[cols_dim] = output_columns;
  at::SymDimVector output_shape_vec(output_shape);

  return at::empty_symint(
      output_shape_vec,
      weight.options().dtype(weight.scalar_type()),
      weight.suggest_memory_format());
}

namespace {

// TODO: Extend support to N-D batched embeddings, similar to
// qembeddingbag_byte_prepack
Tensor _qembeddingbag_nbit_prepack_helper(
    const Tensor& weight,
    int bit_width,
    const bool optimized_qparams,
    const int64_t nbins,
    const double ratio) {
  TORCH_CHECK(
      weight.scalar_type() == at::ScalarType::Float ||
          weight.scalar_type() == at::ScalarType::Half,
      "'qembeddingbag_nbit_prepack' only supports float32 or float16.");

  int64_t embedding_rows = weight.size(0);
  int64_t embedding_cols = weight.size(1);

  Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());

  TORCH_CHECK(
      bit_width == 4 || bit_width == 2,
      "bit_width must be either 2 or 4 to use 'qembeddingbag_nbit_prepack'. "
      "For 8bit, consider using 'embedding_bag_byte_prepack'.");

  int NUM_ELEM_PER_BYTE = 8 / bit_width;
  TORCH_CHECK(
      weight_contig.size(weight.dim() - 1) % NUM_ELEM_PER_BYTE == 0,
      "qembeddingbag_",
      std::to_string(bit_width),
      "bit_prepack only works when the number of columns is a multiple of ",
      std::to_string(NUM_ELEM_PER_BYTE));

  // The "fused" representation stores the scale and bias with the
  // row-wise quantized data in one tensor.
  // Since we represent the scale and bias in 16-bit float, we'll use the
  // last 4 bytes of each row for scale (2 bytes) and bias (2 bytes).
  // | ... quantized data ... | scale | bias |
  // | number_of_columns      |  2B   |  2B  |
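  //
  // For example (illustrative numbers only): with bit_width == 4 and
  // embedding_cols == 120, each row packs 120 / 2 = 60 bytes of quantized
  // data plus 4 bytes of fp16 scale/bias, so each output row is 64 bytes;
  // with bit_width == 2 the same row would take 120 / 4 + 4 = 34 bytes.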
  std::vector<int64_t> output_shape = {
      embedding_rows,
      static_cast<std::int64_t>(
          (embedding_cols + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE +
          2 * sizeof(at::Half))};
  auto output = at::empty(
      output_shape,
      weight_contig.options().dtype(at::kByte),
      weight_contig.suggest_memory_format());
  auto* output_data = output.data_ptr<uint8_t>();

#ifdef USE_FBGEMM
  if (!optimized_qparams) {
    if (weight_contig.scalar_type() == at::ScalarType::Half) {
      const auto weight_data =
          static_cast<fbgemm::float16*>(weight_contig.data_ptr());
      at::parallel_for(
          0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
            fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<
                fbgemm::float16>(
                bit_width,
                weight_data + start_idx * embedding_cols,
                end_idx - start_idx,
                embedding_cols,
                output_data + start_idx * output_shape[1]);
          });
    } else {
      const auto weight_data = weight_contig.data_ptr<float>();
      at::parallel_for(
          0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
            fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
                bit_width,
                weight_data + start_idx * embedding_cols,
                end_idx - start_idx,
                embedding_cols,
                output_data + start_idx * output_shape[1]);
          });
    }
  } else {
#endif // USE_FBGEMM
    const auto output_columns = output.size(output.dim() - 1);
    const auto float_weight =
        weight_contig.scalar_type() == at::ScalarType::Half
        ? weight_contig.to(at::ScalarType::Float)
        : std::move(weight_contig);
    const auto weight_data = float_weight.data_ptr<float>();
    for (const auto row : c10::irange(embedding_rows)) {
      const float* input_row = weight_data + row * embedding_cols;
      std::uint8_t* output_row = output_data + row * output_columns;

      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      float Xmin, Xmax;
      if (optimized_qparams) {
        auto [xmax_tensor, xmin_tensor] = at::choose_qparams_optimized(
            float_weight[row], embedding_cols, nbins, ratio, bit_width);
        TORCH_CHECK(
            xmax_tensor.numel() == 1 && xmin_tensor.numel() == 1,
            "Expected choose_qparams_optimized to return min/max tensors of size 1");
        Xmax = xmax_tensor.item<float>();
        Xmin = xmin_tensor.item<float>();
      } else {
        Xmin = *std::min_element(input_row, input_row + embedding_cols);
        Xmax = *std::max_element(input_row, input_row + embedding_cols);
      }
      Xmin = static_cast<at::Half>(Xmin);
      float range = Xmax - Xmin;
      // Set scale to 1.0f for the corner case of Xmax == Xmin.
      // Any non-zero scale would work because during quantization
      // (X - Xmin) / scale will be 0 for all X unless scale is 0.
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      at::Half scale = range == 0 ? 1.0f : range / ((1 << bit_width) - 1);
      float inverse_scale = scale == 0 ? 1.0f : 1.0f / scale;
      if (scale == 0 || std::isinf(inverse_scale)) {
        // Corner case handling when Xmax == Xmin
        // Any scale would work because X - Xmin will be 0 for all X
        scale = 1.0f;
        inverse_scale = 1.0f;
      }
      // Update the scale and zero_point of each row.
      at::Half* output_row_scale_zp = reinterpret_cast<at::Half*>(
          output_row +
          (embedding_cols + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);

      output_row_scale_zp[0] = scale;
      output_row_scale_zp[1] = Xmin;

      // Pack the weight values.
      for (const auto col : c10::irange(embedding_cols)) {
        float X = input_row[col];
        std::uint8_t quantized = std::max(
            0,
            std::min<int>(
                lrintf((X - Xmin) * inverse_scale), (1 << bit_width) - 1));
        // We pack NUM_ELEM_PER_BYTE values per byte: column 0 goes into the
        // lowest bits and each subsequent column into successively higher bits.
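        // Worked example (illustrative values, bit_width == 4): quantized
        // values 3 at col 0 and 10 at col 1 pack into 0x3 | (0xA << 4) == 0xA3.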
        if (col % NUM_ELEM_PER_BYTE == 0) {
          output_row[col / NUM_ELEM_PER_BYTE] = quantized;
        } else {
          output_row[col / NUM_ELEM_PER_BYTE] |=
              (quantized << ((col % NUM_ELEM_PER_BYTE) * bit_width));
        }
      } // embedding_cols
    } // embedding_rows
#ifdef USE_FBGEMM
  }
#endif // USE_FBGEMM

  return output;
}

// Applies 4-bit row-wise quantization by determining the range
// (maximum - minimum) and bias (minimum value) of each row in the input
// matrix, and then scaling each element to a 4-bit number between 0 and
// 15.
// To later de-quantize values, the scale (range / 15) and zero_point
// are stored alongside the data. More precisely, each row first has quantized
// values, and then 2-byte fp16 scale and 2-byte zero_offset.
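//
// Worked example (illustrative numbers only): for a row with Xmin = -1.0 and
// Xmax = 2.0, scale = (2.0 - (-1.0)) / 15 = 0.2, and the element 0.6
// quantizes to round((0.6 - (-1.0)) / 0.2) = 8, which de-quantizes back to
// 8 * 0.2 + (-1.0) = 0.6.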
Tensor qembeddingbag_4bit_prepack(
    const Tensor& weight,
    const bool optimized_qparams,
    const int64_t nbins,
    const double ratio) {
  return _qembeddingbag_nbit_prepack_helper(
      weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio);
}

// Applies 2-bit row-wise quantization by determining the range
// (maximum - minimum) and bias (minimum value) of each row in the input
// matrix, and then scaling each element to a 2-bit number between 0 and
// 3.
// To later de-quantize values, the scale (range / 3) and zero_point
// are stored alongside the data. More precisely, each row first has quantized
// values, and then 2-byte fp16 scale and 2-byte zero_offset.
// TODO: Add a 2-bit embedding lookup operator.
Tensor qembeddingbag_2bit_prepack(
    const Tensor& weight,
    const bool optimized_qparams,
    const int64_t nbins,
    const double ratio) {
  return _qembeddingbag_nbit_prepack_helper(
      weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio);
}

class QEmbeddingPackWeights final {
 public:
  static c10::intrusive_ptr<EmbeddingPackedParamsBase> run(at::Tensor weight) {
    return PackedEmbeddingBagWeight::prepack(std::move(weight));
  }
};

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"),
      TORCH_FN(qembeddingbag_byte_prepack));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"),
      TORCH_FN(qembeddingbag_4bit_prepack));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"),
      TORCH_FN(qembeddingbag_2bit_prepack));
}

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_prepack"),
      TORCH_FN(QEmbeddingPackWeights::run));
}

TORCH_LIBRARY_IMPL(quantized, Meta, m) {
  m.impl(
      "quantized::embedding_bag_byte_prepack", qembeddingbag_byte_prepack_meta);
}

} // namespace
} // namespace native
} // namespace at