// external/pytorch/torch/csrc/distributed/c10d/quantization/quantization.cpp
#include <torch/csrc/distributed/c10d/quantization/quantization.h>
#include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>

#include <c10/util/irange.h>
#include <torch/library.h>

namespace torch::distributed::c10d::quantization {

// TODO: The kernels are copied from fbgemm_gpu; we should dedup them later.
static void FloatToBFloat16Quantized_ref(
    const float* const input,
    const size_t nrows,
    const size_t ncols,
    uint16_t* const output) {
  for (const auto row : c10::irange(nrows)) {
    const float* input_row = input + row * ncols;
    uint16_t* output_row = output + row * ncols;

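    // Take the high 16 bits of each fp32 bit pattern; adding 1 << 15 before
    // the shift rounds to the nearest bfloat16 instead of truncating.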
    for (const auto col : c10::irange(ncols)) {
      output_row[col] =
          (*reinterpret_cast<const uint32_t*>(input_row + col) + (1 << 15)) >>
          16;
    }
  }
}

static void BFloat16QuantizedToFloat_ref(
    const at::BFloat16* const input,
    const size_t nrows,
    const size_t ncols,
    float* const output) {
  for (const auto row : c10::irange(nrows)) {
    const at::BFloat16* input_row = input + row * ncols;
    float* output_row = output + row * ncols;

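    // Widen each 16-bit bfloat16 pattern into the high half of a 32-bit word;
    // the low 16 bits of the resulting fp32 mantissa are zero.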
    for (const auto col : c10::irange(ncols)) {
      uint32_t val_fp32 = static_cast<uint32_t>(
                              reinterpret_cast<const uint16_t*>(input_row)[col])
          << 16;
      reinterpret_cast<uint32_t*>(output_row)[col] = val_fp32;
    }
  }
}

at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) {
  TENSOR_ON_CPU(input);
  // Currently only 2D inputs are supported
  TENSOR_NDIM_EQUALS(input, 2);

  const auto input_sizes = input.sizes();
  const auto nrows = input_sizes[0];
  const auto ncols = input_sizes[1];
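  // NOTE: at::kHalf is used only as an opaque 16-bit container; the stored
  // values are raw bfloat16 bit patterns, not IEEE fp16 numbers.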
  auto output = at::empty({nrows, ncols}, input.options().dtype(at::kHalf));

  FloatToBFloat16Quantized_ref(
      input.const_data_ptr<float>(),
      nrows,
      ncols,
      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>()));

  return output;
}

at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) {
  TENSOR_ON_CPU(input);
  // Currently only 2D inputs are supported
  TENSOR_NDIM_EQUALS(input, 2);

  const auto input_sizes = input.sizes();
  const auto nrows = input_sizes[0];
  const auto ncols = input_sizes[1];

  auto output = at::empty({nrows, ncols}, input.options().dtype(at::kFloat));
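  // The Half-typed input is expected to hold raw bfloat16 bit patterns (see
  // _float_to_bfloat16_cpu above), so reinterpret rather than convert.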
  BFloat16QuantizedToFloat_ref(
      reinterpret_cast<const at::BFloat16*>(input.const_data_ptr<at::Half>()),
      nrows,
      ncols,
      output.mutable_data_ptr<float>());

  return output;
}

TORCH_LIBRARY(quantization, m) {
  m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor");
  m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor");
}

TORCH_LIBRARY_IMPL(quantization, CPU, m) {
  m.impl("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cpu);
  m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu);
}
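
// Usage sketch: once this library is loaded, the ops above are reachable
// through the dispatcher (e.g. torch.ops.quantization._FloatToBfloat16Quantized(t)
// from Python). A 2D fp32 CPU tensor can round-trip through the C++ entry
// points like so:
//
//   at::Tensor t = at::randn({4, 8});          // fp32 input
//   at::Tensor q = _float_to_bfloat16_cpu(t);  // Half storage, bf16 bits
//   at::Tensor r = _bfloat16_to_float_cpu(q);  // fp32 at bfloat16 precision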

} // namespace torch::distributed::c10d::quantization