#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <vector>

#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/PackedParams.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_native.h>
#include <ATen/ops/fbgemm_linear_int8_weight_native.h>
#include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
#include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
#endif

#include <c10/util/irange.h>

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#include <fbgemm/FbgemmFP16.h>
#include <fbgemm/QuantUtils.h>
#endif // USE_FBGEMM

namespace caffe2 {
CAFFE_KNOWN_TYPE(c10::intrusive_ptr<LinearPackedParamsBase>);
} // namespace caffe2

#ifdef USE_FBGEMM
namespace caffe2 {
// Required for cpp_custom_type_hack to work
CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
CAFFE_KNOWN_TYPE(c10::intrusive_ptr<PackedLinearWeightFp16>);
} // namespace caffe2
#endif // USE_FBGEMM

namespace at::native {

#ifdef USE_FBGEMM

Tensor fbgemm_linear_int8_weight_fp32_activation(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& packed,
    const Tensor& col_offsets,
    const Scalar& weight_scale,
    const Scalar& weight_zero_point,
    const Tensor& bias) {
  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

  TORCH_WARN_ONCE("fbgemm_linear_int8_weight_fp32_activation is deprecated "
                  "and will be removed in a future PyTorch release.")

  const Tensor input_contig = input.contiguous();
  const float* input_ptr = input_contig.const_data_ptr<float>();

  TORCH_CHECK(input.dim() >= 2);
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
  const int64_t K = input.size(input.dim() - 1);
  TORCH_CHECK(weight.dim() == 2);
  TORCH_CHECK(K == weight.size(1));
  const int64_t N = weight.size(0);
  TORCH_CHECK(bias.dim() == 1);
  TORCH_CHECK(bias.size(0) == N);
  TORCH_CHECK(weight_scale.isFloatingPoint());
  TORCH_CHECK(weight_zero_point.isIntegral(false));

  // Calculate statistics for quantization of the input Tensor
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float x_min;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float x_max;
  fbgemm::FindMinMax(
      /*m=*/input_ptr,
      /*min=*/&x_min,
      /*max=*/&x_max,
      /*len=*/input.numel());

  // Input tensor is quantized as 8-bit unsigned values
  constexpr int kPrecision = 8;
  constexpr bool kIsSigned = false;
  constexpr int kBound = (1 << (kPrecision - 1));

  // Calculate scale and zero point for quantization of input tensor
  auto q_params = fbgemm::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/kIsSigned ? -kBound : 0,
      /*qmax=*/kIsSigned ? (kBound - 1) : (1 << kPrecision) - 1,
      /*preserve_sparsity=*/false);
  q_params.precision = kPrecision;
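
  // A rough sketch of what ChooseQuantizationParams computes for the affine
  // (asymmetric) scheme used here (zero-point nudging and degenerate ranges
  // are handled inside FBGEMM): the range is first widened to include 0, then
  //   scale      ~= (x_max - x_min) / (qmax - qmin)
  //   zero_point ~= qmin - round(x_min / scale), clamped to [qmin, qmax]
  // so that the real value 0.0f maps exactly onto an integer. For example,
  // x_min = -1.0f, x_max = 3.0f with qmin = 0, qmax = 255 gives
  // scale ~= 4.0f / 255 and zero_point ~= 64.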

  // ReQuantizeForFloat requires pointers to the scale and zero point values,
  // since in the case of rowwise quantization these will be arrays rather than
  // scalars. But in this case, we're doing whole-tensor quantization so we
  // just pass pointers to the single scale and zero point values (internally,
  // ReQuantizeForFloat won't index past 0).
  const float weight_scale_float =
      static_cast<float>(weight_scale.to<double>());
  const int32_t weight_zero_point_int32 =
      static_cast<int32_t>(weight_zero_point.to<int64_t>());

  const Tensor bias_contig = bias.contiguous();

  // Allocate output Tensor and a buffer for fbgemmPacked to use
  std::vector<int64_t> output_size = input.sizes().vec();
  output_size.back() = N;
  Tensor output = at::empty(output_size, input.options().dtype(at::kFloat), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor buffer = at::empty(output_size, input.options().dtype(at::kInt), LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  // Pull out the PackBMatrix instance from the owning tensor
  auto& pack_b =
      cpp_custom_type_hack::cast<fbgemm::PackBMatrix<int8_t>>(packed);

  int32_t* col_offsets_data = col_offsets.data_ptr<int32_t>();
  float* bias_contig_data = bias_contig.data_ptr<float>();

  const int num_tasks = at::get_num_threads();
  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
    // This operation does the following:
    // 1) Quantizes the input matrix given the statistics we've calculated
    //    above.
    // 2) Creates a "row buffer" vector with offset values that must be added
    //    to the integer matrix multiplication operation to ensure correctness.
    // 3) Packs the resulting quantized matrix into vector-register and cache
    //    friendly tiles.
    //
    //  Note this is not executed eagerly, but rather within the fbgemmPacked
    //  call below.
    fbgemm::PackAWithQuantRowOffset<uint8_t> pack_a(
        /*trans=*/fbgemm::matrix_op_t::NoTranspose,
        /*nRow=*/M,
        /*nCol=*/K,
        /*smat=*/input_ptr,
        /*ld=*/K,
        /*pmat=*/nullptr, // pack_a manages ownership of `pmat`
        /*scale=*/q_params.scale,
        /*zero_pt=*/q_params.zero_point);
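
    // The row-offset buffer filled in by pack_a during packing holds, for
    // each row i of the quantized activation, the sum over k of A_q[i][k];
    // it is consumed by the ReQuantizeForFloat output pipeline below to
    // cancel the weight zero point.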

    // This is the end of the pipeline, pass the resulting matrix through
    fbgemm::DoNothing<float, float> kDoNothingObj{};
    for (const auto task_id : c10::irange(begin, end)) {
      // After the uint8 * int8 matrix multiplication is performed, this
      // operation does:
      //  1) Add in row and column offsets to the rows and columns, respectively
      //  2) Dequantize the results into floating point
      //  3) Add in the bias term
      fbgemm::ReQuantizeForFloat</* FUSE_RELU */ false> output_proc_obj(
          /*nextop=*/kDoNothingObj,
          /*Aq_scale=*/q_params.scale,
          /*Bq_scale=*/&weight_scale_float,
          /*Aq_zero_point=*/q_params.zero_point,
          /*Bq_zero_point=*/&weight_zero_point_int32,
          /*row_offsets=*/pack_a.getRowOffsetBuffer(),
          /*col_offsets=*/col_offsets_data,
          /*bias=*/bias_contig_data,
          /*nCol=*/N);
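      // Putting the offsets together, the requantization step computes, per
      // output element (a sketch that ignores rounding details):
      //   output[i][j] = Aq_scale * Bq_scale *
      //       (C_int32[i][j]
      //        - Aq_zero_point * col_offsets[j]
      //        - Bq_zero_point * row_offsets[i]) + bias[j]
      // The K * Aq_zero_point * Bq_zero_point cross term drops out because
      // col_offsets already has B_zero_point * K subtracted (see
      // CalcColOffsetsTranspose below).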
      // Do the GEMM
      fbgemm::fbgemmPacked(
          /*packA=*/pack_a,
          /*packB=*/pack_b,
          /*C=*/output.data_ptr<float>(),
          /*C_buffer=*/buffer.data_ptr<int32_t>(),
          /*ldc=*/N,
          /*outProcess=*/output_proc_obj,
          /*thread_id=*/task_id,
          /*num_threads=*/num_tasks);
    }
  });

  return output;
}

Tensor fbgemm_linear_int8_weight(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& packed,
    const Tensor& col_offsets,
    const Scalar& weight_scale,
    const Scalar& weight_zero_point,
    const Tensor& bias) {
  return at::native::fbgemm_linear_int8_weight_fp32_activation(
      input,
      weight,
      packed,
      col_offsets,
      weight_scale,
      weight_zero_point,
      bias);
}

namespace {

// Calculate the column offsets.
// Note that each entry is the column sum minus the scalar term
// B_zero_point * K, whereas the row_offsets created by
// PackAWithQuantRowOffset are only the sums of the A rows.
void CalcColOffsetsTranspose(
    int K,
    int N,
    const int8_t* Bint8,
    int32_t B_zero_point,
    int32_t* col_offsets) {
  for (const auto i : c10::irange(N)) {
    int32_t sum = 0;
    for (const auto j : c10::irange(K)) {
      sum += Bint8[i * K + j];
    }
    col_offsets[i] = sum - B_zero_point * K;
  }
}
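// Worked example for the function above: with K = 2, B_zero_point = 1 and the
// two quantized values of one output column (one row of the stored [N, K]
// weight) equal to {3, 5}, the entry is (3 + 5) - 1 * 2 = 6, i.e. the sum of
// (B_q[k] - B_zero_point) over k.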

} // namespace

std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
    const Tensor& weight) {
  TORCH_WARN_ONCE("fbgemm_linear_quantize_weight is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
  const Tensor weight_contig = weight.contiguous();

  // Calculate weight statistics
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float w_min;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float w_max;
  fbgemm::FindMinMax(
      /*m=*/weight_contig.data_ptr<float>(),
      /*min=*/&w_min,
      /*max=*/&w_max,
      /*len=*/weight_contig.numel());

  // Choose parameters for quantizing the weight as 8-bit signed integer
  constexpr bool kIsSigned = true;
  constexpr int kPrecision = 8;
  constexpr int kBound = (1 << (kPrecision - 1));
  auto q_params = fbgemm::ChooseQuantizationParams(
      /*min=*/w_min,
      /*max=*/w_max,
      /*qmin=*/kIsSigned ? -kBound : 0,
      /*qmax=*/kIsSigned ? (kBound - 1) : (1 << kPrecision) - 1,
      /*preserve_sparsity=*/false);
  q_params.precision = kPrecision;

  Tensor quantized = at::native::empty_like(
      weight_contig,
      at::kChar,
      weight_contig.options().layout_opt(),
      weight_contig.options().device_opt(),
      weight_contig.options().pinned_memory_opt(),
      LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  // Tensor quantized = at::native::empty_cpu(
  //     weight_contig.sizes(), weight_contig.options().dtype(at::kChar));
  fbgemm::Quantize<int8_t, false /*LEGACY*/>(
      /*src=*/weight_contig.data_ptr<float>(),
      /*dst=*/quantized.data_ptr<int8_t>(),
      /*len=*/weight_contig.numel(),
      /*qparams=*/q_params);

  // Calculate column offsets of the weight and store them away in a tensor.
  // Similarly to quantization, this can be done once and cached.
  Tensor col_offsets = at::empty(
      {weight_contig.size(0)},
      at::kInt,
      weight_contig.options().layout_opt(),
      weight_contig.options().device_opt(),
      weight_contig.options().pinned_memory_opt(),
      LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  CalcColOffsetsTranspose(
      /*K=*/quantized.size(1),
      /*N=*/quantized.size(0),
      /*Bint8=*/quantized.data_ptr<int8_t>(),
      /*B_zero_point=*/q_params.zero_point,
      /*col_offsets=*/col_offsets.data_ptr<int32_t>());

  return std::make_tuple(
      quantized, col_offsets, q_params.scale, q_params.zero_point);
}
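
// A rough caller-side sketch (not part of this file) of how the deprecated
// int8 path chains together, assuming the generated at:: wrappers for the ops
// defined above:
//
//   // float input [*, K], float weight [N, K], float bias [N]
//   auto [w_int8, col_offsets, w_scale, w_zero_point] =
//       at::fbgemm_linear_quantize_weight(weight);
//   Tensor packed = at::fbgemm_pack_quantized_matrix(w_int8);
//   Tensor out = at::fbgemm_linear_int8_weight_fp32_activation(
//       input, w_int8, packed, col_offsets, w_scale, w_zero_point, bias);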

Tensor fbgemm_pack_quantized_matrix(const Tensor& weight) {
  TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
  const int64_t K = weight.size(1);
  const int64_t N = weight.size(0);
  const Tensor weight_contig = weight.contiguous();
  const int8_t* weight_ptr = weight_contig.const_data_ptr<int8_t>();
  auto ptr = std::make_unique<fbgemm::PackBMatrix<int8_t>>(
      /*trans=*/fbgemm::matrix_op_t::Transpose,
      /*nRow=*/K,
      /*nCol=*/N,
      /*smat=*/weight_ptr,
      /*ld=*/K,
      /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
      /*groups=*/1);
  return cpp_custom_type_hack::create(std::move(ptr), weight.options());
}

Tensor fbgemm_pack_quantized_matrix(
    const Tensor& weight,
    int64_t K,
    int64_t N) {
  // Replace after https://github.com/pytorch/pytorch/issues/24354 is fixed
  // TORCH_WARN(
  //     "fbgemm_pack_quantized_matrix(weight, K, N) will be deprecated soon."
  //     "Please use fbgemm_pack_quantized_matrix(weight) instead.");
  return at::native::fbgemm_pack_quantized_matrix(weight);
}

namespace {

float RawUint16ToFp16(unsigned short value) {
  // Convert a raw 16-bit half-precision floating point number
  // to a single-precision floating point number.
  const unsigned short sign_bits = value >> 15;
  const unsigned short exponent_bits = value >> 10 & 0x1f;
  const unsigned short significand_bits = value & 0x3ff;

  const float sign = sign_bits ? -1 : 1;
  const float significand =
      // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
      1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const float exponent = exponent_bits - 0xf;

  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  return sign * std::ldexp(significand, exponent);
}
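
// Worked example: 0x7BFF (the largest finite fp16 value) has sign_bits = 0,
// exponent_bits = 0x1E = 30 and significand_bits = 0x3FF = 1023, so the
// result is +1 * (1 + 1023 / 1024) * 2^(30 - 15) = 1.9990234375f * 32768
// = 65504.0f, which is the kFp16Max used by HandleWeightsSaturation below.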

template <typename T>
bool CheckAndSaturate(T max_val, T* element) {
  if (*element > max_val) {
    *element = max_val;
    return true;
  }
  if (*element < -max_val) {
    *element = -max_val;
    return true;
  }
  return false;
}

// FP16 can only represent weight magnitudes in the range [5.96e-8, 65504].
// Elements that fall outside this range are saturated to the max or min value
// representable in FP16.
void HandleWeightsSaturation(int64_t N, float* weight) {
  const float kFp16Max = RawUint16ToFp16(0x7BFF);
  bool found_out_of_range = false;
  for (const auto i : c10::irange(N)) {
    if (CheckAndSaturate<float>(kFp16Max, weight + i)) {
      found_out_of_range = true;
    }
  }
  if (found_out_of_range) {
    TORCH_WARN("FOUND weight out of range ");
  }
}

} // namespace

Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
  TORCH_WARN_ONCE("fbgemm_pack_gemm_matrix_fp16 is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

  const int64_t K = weight.size(1);
  const int64_t N = weight.size(0);
  Tensor weight_contig = weight.contiguous();
  float* weight_contig_ptr = weight_contig.data_ptr<float>();
  HandleWeightsSaturation(K * N, weight_contig_ptr);

  // TODO(mingzhe09088):
  // Consider using a functor here in PackedGemmMatrixFP16
  // Comments from (XQ): Not entirely sure this make_unique is safe. make_unique
  // is created with regular "new", and freed through TypeMetaData::deleteFn in
  // this function. This is perfectly fine if the tensors are created and freed
  // within this translation unit. It might be very problematic if that tensor
  // flows across dll boundaries.
  auto ptr = std::make_unique<fbgemm::PackedGemmMatrixFP16>(
      fbgemm::matrix_op_t::Transpose, K, N, 1, weight_contig_ptr);
  c10::intrusive_ptr<LinearPackedParamsBase> packed_weight =
      c10::make_intrusive<PackedLinearWeightFp16>(std::move(ptr), std::nullopt);
  auto unique_ptr_wrapper =
      std::make_unique<decltype(packed_weight)>(std::move(packed_weight));
  return cpp_custom_type_hack::create(
      std::move(unique_ptr_wrapper), weight.options());
}

Tensor fbgemm_linear_fp16_weight_fp32_activation(
    const Tensor& input,
    const Tensor& packed_weight,
    const Tensor& bias) {
  TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

  const Tensor input_contig = input.contiguous();
  const float* input_ptr = input_contig.const_data_ptr<float>();

  // Pull out the PackedGemmMatrixFP16 instance from the owning tensor
  const fbgemm::PackedGemmMatrixFP16& packed_weight_fp16 =
      *c10::dynamic_intrusive_pointer_cast<PackedLinearWeightFp16>(
           cpp_custom_type_hack::cast<
               c10::intrusive_ptr<LinearPackedParamsBase>>(packed_weight))
           ->w;

  TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows())
  TORCH_CHECK(input.dim() >= 2);
  TORCH_CHECK(bias.dim() == 1);

  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
  const int64_t N = packed_weight_fp16.numCols();
  std::vector<int64_t> output_size = input.sizes().vec();
  output_size.back() = N;
  Tensor output = at::empty(output_size, input.options().dtype(at::kFloat));

  // Call the fp16 gemm interface
  fbgemm::cblas_gemm_compute(
      fbgemm::matrix_op_t::NoTranspose,
      M,
      input_ptr,
      packed_weight_fp16,
      0.0f,
      output.data_ptr<float>());

  // Add bias term
  output.add_(bias);

  return output;
}
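
// A rough caller-side sketch (not part of this file) of the deprecated fp16
// path, again assuming the generated at:: wrappers:
//
//   // float input [*, K], float weight [N, K], float bias [N]
//   Tensor packed = at::fbgemm_pack_gemm_matrix_fp16(weight);
//   Tensor out =
//       at::fbgemm_linear_fp16_weight_fp32_activation(input, packed, bias);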

Tensor fbgemm_linear_fp16_weight(
    const Tensor& input,
    const Tensor& packed_weight,
    const Tensor& bias) {
  return at::native::fbgemm_linear_fp16_weight_fp32_activation(
      input, packed_weight, bias);
}

#else // USE_FBGEMM

Tensor fbgemm_linear_int8_weight_fp32_activation(
    const Tensor& /*input*/,
    const Tensor& /*weight*/,
    const Tensor& /*packed*/,
    const Tensor& /*col_offsets*/,
    const Scalar& /*weight_scale*/,
    const Scalar& /*weight_zero_point*/,
    const Tensor& /*bias*/) {
  TORCH_WARN_ONCE("fbgemm_linear_int8_weight_fp32_activation is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

Tensor fbgemm_linear_int8_weight(
    const Tensor& /*input*/,
    const Tensor& /*weight*/,
    const Tensor& /*packed*/,
    const Tensor& /*col_offsets*/,
    const Scalar& /*weight_scale*/,
    const Scalar& /*weight_zero_point*/,
    const Tensor& /*bias*/) {
  TORCH_WARN_ONCE("fbgemm_linear_int8_weight is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
    const Tensor& /*weight*/) {
  TORCH_WARN_ONCE("fbgemm_linear_quantize_weight is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

Tensor fbgemm_pack_quantized_matrix(const Tensor& /*input*/) {
  TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

Tensor fbgemm_pack_quantized_matrix(
    const Tensor& /*input*/,
    int64_t /*K*/,
    int64_t /*N*/) {
  TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
  TORCH_WARN_ONCE("fbgemm_pack_gemm_matrix_fp16 is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

Tensor fbgemm_linear_fp16_weight_fp32_activation(
    const Tensor& input,
    const Tensor& packed_weight,
    const Tensor& bias) {
  TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

Tensor fbgemm_linear_fp16_weight(
    const Tensor& input,
    const Tensor& packed_weight,
    const Tensor& bias) {
  TORCH_WARN_ONCE("fbgemm_linear_fp16_weight is deprecated "
                  "and will be removed in a future PyTorch release.")

  // We make a strong guarantee that models using these operators will have the
  // same numerics across different machines. Therefore, we do not provide a
  // fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

#endif // USE_FBGEMM

} // namespace at::native