#include <torch/csrc/distributed/c10d/quantization/quantization.h>
#include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>
#include <torch/library.h>

namespace torch::distributed::c10d::quantization {

// TODO: These kernels are copied from fbgemm_gpu; we should dedup them later.

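// Quantize a row-major float32 buffer to bfloat16. A bfloat16 value is the
// upper 16 bits of an IEEE-754 float32, so the kernel reinterprets each
// element's bit pattern, adds half of the discarded range (1 << 15) to round
// to nearest, and keeps the top 16 bits. Ties round up in the bit pattern;
// this is not round-to-nearest-even.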
static void FloatToBFloat16Quantized_ref(
    const float* const input,
    const size_t nrows,
    const size_t ncols,
    uint16_t* const output) {
  for (const auto row : c10::irange(nrows)) {
    const float* input_row = input + row * ncols;
    uint16_t* output_row = output + row * ncols;

    for (const auto col : c10::irange(ncols)) {
      output_row[col] =
          (*reinterpret_cast<const uint32_t*>(input_row + col) + (1 << 15)) >>
          16;
    }
  }
}

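// Dequantize bfloat16 back to float32 by widening each 16-bit pattern and
// shifting it into the high half of a 32-bit word; the low 16 mantissa bits
// come back as zeros, so this direction is exact (no rounding).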
static void BFloat16QuantizedToFloat_ref(
    const at::BFloat16* const input,
    const size_t nrows,
    const size_t ncols,
    float* const output) {
  for (const auto row : c10::irange(nrows)) {
    const at::BFloat16* input_row = input + row * ncols;
    float* output_row = output + row * ncols;

    for (const auto col : c10::irange(ncols)) {
      uint32_t val_fp32 =
          static_cast<uint32_t>(
              reinterpret_cast<const uint16_t*>(input_row)[col])
          << 16;
      reinterpret_cast<uint32_t*>(output_row)[col] = val_fp32;
    }
  }
}

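// Op entry point: checks for a 2D CPU tensor and quantizes it. Note that the
// output is allocated with dtype at::kHalf purely as a 16-bit container; its
// storage holds raw bfloat16 bit patterns, so callers are expected to
// reinterpret it rather than read it as IEEE half precision.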
at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) {
  TENSOR_ON_CPU(input);
  // Currently it supports 2D inputs
  TENSOR_NDIM_EQUALS(input, 2);

  const auto input_sizes = input.sizes();
  const auto nrows = input_sizes[0];
  const auto ncols = input_sizes[1];
  auto output = at::empty({nrows, ncols}, input.options().dtype(at::kHalf));

  FloatToBFloat16Quantized_ref(
      input.const_data_ptr<float>(),
      nrows,
      ncols,
      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>()));

  return output;
}

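// Inverse op: interprets the at::kHalf input as raw bfloat16 bits (matching
// the container produced by _float_to_bfloat16_cpu) and widens it back to a
// float32 tensor.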
at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) {
  TENSOR_ON_CPU(input);
  // Currently it supports 2D inputs
  TENSOR_NDIM_EQUALS(input, 2);

  const auto input_sizes = input.sizes();
  const auto nrows = input_sizes[0];
  const auto ncols = input_sizes[1];

  auto output = at::empty({nrows, ncols}, input.options().dtype(at::kFloat));
  BFloat16QuantizedToFloat_ref(
      reinterpret_cast<const at::BFloat16*>(input.const_data_ptr<at::Half>()),
      nrows,
      ncols,
      output.mutable_data_ptr<float>());

  return output;
}

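// Register the op schemas under the `quantization` namespace. Both ops take a
// single 2D tensor and return a new tensor; the CPU kernels are bound below.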
TORCH_LIBRARY(quantization, m) {
  m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor");
  m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor");
}

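// Bind the CPU kernels to the schemas above. Once this translation unit is
// loaded, the ops are dispatchable from Python through the torch.ops
// namespace, e.g. (a hypothetical call, assuming a float32 CPU tensor `t`):
//   q = torch.ops.quantization._FloatToBfloat16Quantized(t)
//   r = torch.ops.quantization._Bfloat16QuantizedToFloat(q)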
TORCH_LIBRARY_IMPL(quantization, CPU, m) {
  m.impl("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cpu);
  m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu);
}

} // namespace torch::distributed::c10d::quantization