1 #include <c10/cuda/CUDAGuard.h>
2 #include <torch/csrc/distributed/c10d/Utils.hpp>
3 #include <torch/csrc/distributed/c10d/quantization/quantization_gpu.h>
4 #include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>
5 #include <torch/library.h>
6
7 // TODO: The kernels are copied from fbgemm_gpu, we should dedup them later
8
9 // FP32 -> BF16 kernel
// Converts a row-major nrows x ncols matrix of float32 values to bfloat16,
// emitted as raw uint16_t payloads.
//
// Rounding: add 2^15 to the float's bit pattern, then truncate the low 16
// bits (bfloat16 is the top 16 bits of a float32). NOTE(review): as in the
// fbgemm_gpu kernel this was copied from, a NaN whose low mantissa bits carry
// on the +2^15 can round to +/-inf — confirm callers never pass NaN.
//
// Launch layout: 2D grid-stride loop — x (threadIdx/blockIdx.x) walks
// columns, y walks rows — so any grid/block configuration is correct.
__global__ void _float_to_bfloat16_cuda_kernel(
    const float* __restrict__ input,
    const size_t nrows,
    const size_t ncols,
    uint16_t* __restrict__ output) {
  const auto row_incre = blockDim.y * gridDim.y;
  const auto col_incre = blockDim.x * gridDim.x;
  for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
       row += row_incre) {
    // Row base pointers are invariant across the column loop.
    const float* input_row = input + row * ncols;
    uint16_t* output_row = output + row * ncols;
    for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
         col += col_incre) {
      // Add 2^15 and right shift 16 to do round-nearest
      output_row[col] =
          (*reinterpret_cast<const uint32_t*>(input_row + col) + (1 << 15)) >>
          16;
    }
  }
}
30
31 // BF16 -> FP32 kernel
// Expands a row-major nrows x ncols matrix of bfloat16 values (raw uint16_t
// payloads) back to float32. bfloat16 holds the high 16 bits of a float32,
// so the conversion is exact: zero-extend to 32 bits and shift left 16.
//
// Launch layout: 2D grid-stride loop — x walks columns, y walks rows — so
// any grid/block configuration is correct.
__global__ void _bfloat16_to_float_cuda_kernel(
    const uint16_t* __restrict__ input,
    const size_t nrows,
    const size_t ncols,
    float* __restrict__ output) {
  const auto row_incre = blockDim.y * gridDim.y;
  const auto col_incre = blockDim.x * gridDim.x;
  for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
       row += row_incre) {
    // Hoisted out of the column loop: row base pointers are loop-invariant.
    const uint16_t* input_row = input + row * ncols;
    float* output_row = output + row * ncols;
    for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
         col += col_incre) {
      const uint32_t val_fp32 = static_cast<uint32_t>(input_row[col]) << 16;
      // __uint_as_float reinterprets the bit pattern in a register, avoiding
      // the strict-aliasing hazard of storing through a uint32_t* view of
      // the float output buffer.
      output_row[col] = __uint_as_float(val_fp32);
    }
  }
}
52
53 namespace torch::distributed::c10d::quantization {
54
// Quantizes a 2D float32 CUDA tensor to bfloat16 (round-to-nearest).
//
// Returns a newly allocated tensor of the same shape on the same device.
// When the build has a native NCCL BF16 datatype the result is tagged
// at::kBFloat16; otherwise the same 16-bit payloads are carried in an
// at::kHalf tensor (only the dtype tag differs — the bits are bfloat16).
at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
  TENSOR_ON_CUDA_GPU(input);
  // Currently it supports 2D inputs
  TENSOR_NDIM_EQUALS(input, 2);

  // Make sure allocation and launch happen on the input's device.
  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(input.get_device());

  const auto nrows = input.size(0);
  const auto ncols = input.size(1);
  const size_t output_columns = ncols;

  auto output = at::empty(
      {nrows, ncols},
#if HAS_NCCL_BF16_DATATYPE
      input.options().dtype(at::kBFloat16));
#else
      input.options().dtype(at::kHalf));
#endif

  // Empty tensor: nothing to convert, and a zero-sized grid launch would be
  // invalid.
  if (nrows == 0 || ncols == 0) {
    return output;
  }

  // 256 threads per block: x covers columns (up to 256 wide), y takes the
  // remaining factor of the block for rows.
  constexpr size_t threads_per_block = 256;
  const auto blockDim_x = std::min(output_columns, threads_per_block);
  dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
  const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
  // Cap grid.y at the hardware limit of 65535; the kernel's grid-stride
  // loop covers any remaining rows.
  const auto gridDim_y =
      std::min<size_t>((nrows + blockDim.y - 1) / blockDim.y, 65535u);
  dim3 gridDim(gridDim_x, gridDim_y);

  _float_to_bfloat16_cuda_kernel<<<
      gridDim,
      blockDim,
      0,
      at::cuda::getCurrentCUDAStream()>>>(
      input.const_data_ptr<float>(),
      nrows,
      ncols,
#if HAS_NCCL_BF16_DATATYPE
      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::BFloat16>())
#else
      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>())
#endif
  );
  // Surface launch-configuration errors immediately.
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return output;
}
105
// Dequantizes a 2D bfloat16 CUDA tensor (dtype kBFloat16, or kHalf carrying
// raw bfloat16 payloads in builds without a native NCCL BF16 datatype) back
// to float32. The conversion is exact; returns a new tensor of the same
// shape on the same device.
at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) {
  TENSOR_ON_CUDA_GPU(input);
  // Currently it supports 2D inputs
  TENSOR_NDIM_EQUALS(input, 2);

  // Make sure allocation and launch happen on the input's device.
  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(input.get_device());

  const auto nrows = input.size(0);
  const auto ncols = input.size(1);
  const size_t output_columns = ncols;

  auto output = at::empty(
      {nrows, ncols},
      input.options().dtype(at::kFloat));

  // Empty tensor: nothing to convert, and a zero-sized grid launch would be
  // invalid.
  if (nrows == 0 || ncols == 0) {
    return output;
  }

  // 256 threads per block: x covers columns (up to 256 wide), y takes the
  // remaining factor of the block for rows.
  constexpr size_t threads_per_block = 256;

  const auto blockDim_x = std::min(output_columns, threads_per_block);
  dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
  const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
  // Cap grid.y at the hardware limit of 65535; the kernel's grid-stride
  // loop covers any remaining rows.
  const auto gridDim_y =
      std::min<size_t>((nrows + blockDim.y - 1) / blockDim.y, 65535u);
  dim3 gridDim(gridDim_x, gridDim_y);

  _bfloat16_to_float_cuda_kernel<<<
      gridDim,
      blockDim,
      0,
      at::cuda::getCurrentCUDAStream()>>>(
#if HAS_NCCL_BF16_DATATYPE
      reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::BFloat16>()),
#else
      reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::Half>()),
#endif
      nrows,
      ncols,
      output.mutable_data_ptr<float>());
  // Surface launch-configuration errors immediately.
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return output;
}
152
// Registers `function` as the CUDA-dispatch-key implementation of operator
// `name` on the library handle `m`.
#define DISPATCH_TO_CUDA(name, function) \
  m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function)))

// CUDA implementations for the `quantization` operator namespace.
TORCH_LIBRARY_IMPL(quantization, CUDA, m) {
  DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda);
  DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda);
}
160
161 } // namespace torch::distributed::c10d::quantization
162