xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/cuda/CUDAGuard.h>
2 #include <torch/csrc/distributed/c10d/Utils.hpp>
3 #include <torch/csrc/distributed/c10d/quantization/quantization_gpu.h>
4 #include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>
5 #include <torch/library.h>
6 
7 // TODO: The kernels are copied from fbgemm_gpu, we should dedup them later
8 
9 // FP32 -> BF16 kernel
// FP32 -> BF16 kernel.
//
// Converts a 2D row-major float matrix (nrows x ncols) into bfloat16 values
// stored as raw uint16_t. Uses a 2D grid-stride loop (y over rows, x over
// columns), so any launch configuration covers the whole matrix.
//
// Rounding: "add 2^15 then truncate to the top 16 bits" of the IEEE-754
// single-precision bit pattern, i.e. round-half-up on the raw bits (copied
// from fbgemm_gpu). NOTE(review): this bit trick does not special-case
// NaN/Inf payloads — confirm callers never rely on exact NaN propagation.
__global__ void _float_to_bfloat16_cuda_kernel(
    const float* __restrict__ input,
    const size_t nrows,
    const size_t ncols,
    uint16_t* __restrict__ output) {
  // Use size_t for indices and strides: the original `auto` deduced 32-bit
  // unsigned ints from the blockIdx/blockDim arithmetic, which can wrap when
  // nrows/ncols exceed the unsigned-int range.
  const size_t row_incre = static_cast<size_t>(blockDim.y) * gridDim.y;
  const size_t col_incre = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t row = static_cast<size_t>(blockIdx.y) * blockDim.y + threadIdx.y;
       row < nrows;
       row += row_incre) {
    const float* input_row = input + row * ncols;
    uint16_t* output_row = output + row * ncols;
    for (size_t col =
             static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         col < ncols;
         col += col_incre) {
      // Add 2^15 and right shift 16 to do round-nearest; the high 16 bits of
      // an FP32 bit pattern are exactly the corresponding BF16 bit pattern.
      output_row[col] =
          (*reinterpret_cast<const uint32_t*>(input_row + col) + (1u << 15)) >>
          16;
    }
  }
}
30 
31 // BF16 -> FP32 kernel
// BF16 -> FP32 kernel.
//
// Expands bfloat16 values (stored as raw uint16_t) in a 2D row-major matrix
// into float. The conversion is exact: the 16 BF16 bits become the high 16
// bits of the FP32 bit pattern, with a zero low half. Uses a 2D grid-stride
// loop (y over rows, x over columns), so any launch configuration is valid.
__global__ void _bfloat16_to_float_cuda_kernel(
    const uint16_t* __restrict__ input,
    const size_t nrows,
    const size_t ncols,
    float* __restrict__ output) {
  // size_t indices/strides avoid 32-bit wrap for very large dimensions (the
  // original `auto` deduced unsigned int from the blockIdx/blockDim math).
  const size_t row_incre = static_cast<size_t>(blockDim.y) * gridDim.y;
  const size_t col_incre = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t row = static_cast<size_t>(blockIdx.y) * blockDim.y + threadIdx.y;
       row < nrows;
       row += row_incre) {
    // Hoisted out of the inner loop: both pointers depend only on `row`.
    const uint16_t* input_row = input + row * ncols;
    float* output_row = output + row * ncols;
    for (size_t col =
             static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         col < ncols;
         col += col_incre) {
      // Place the BF16 bits in the high half of a 32-bit word and store the
      // raw pattern; writing through uint32_t* avoids an int->float convert.
      const uint32_t val_fp32 = static_cast<uint32_t>(input_row[col]) << 16;
      reinterpret_cast<uint32_t*>(output_row)[col] = val_fp32;
    }
  }
}
52 
53 namespace torch::distributed::c10d::quantization {
54 
// Quantizes a 2D float CUDA tensor to bfloat16.
//
// Returns a new tensor of the same shape holding the BF16 bit patterns. When
// NCCL has no native bf16 dtype, the 16-bit payload is carried in a Half
// tensor instead — the stored bits are bfloat16 either way.
at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
  TENSOR_ON_CUDA_GPU(input);
  // Only 2D inputs are supported at present.
  TENSOR_NDIM_EQUALS(input, 2);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(input.get_device());

  const auto nrows = input.size(0);
  const auto ncols = input.size(1);
  const size_t output_columns = ncols;

#if HAS_NCCL_BF16_DATATYPE
  const auto output_dtype = at::kBFloat16;
#else
  const auto output_dtype = at::kHalf;
#endif
  auto output = at::empty({nrows, ncols}, input.options().dtype(output_dtype));

  // Empty tensor: skip the launch (a zero-sized grid would be invalid).
  if (nrows == 0 || ncols == 0) {
    return output;
  }

  // 256 threads per block, laid out so x covers columns first and any
  // leftover threads stack along y over rows.
  constexpr size_t threads_per_block = 256;
  const auto block_x = std::min(output_columns, threads_per_block);
  dim3 blockDim(block_x, threads_per_block / block_x);
  dim3 gridDim(
      (output_columns + blockDim.x - 1) / blockDim.x,
      std::min<size_t>((nrows + blockDim.y - 1) / blockDim.y, 65535u));

  // The kernel writes raw 16-bit patterns, so expose the output buffer as
  // uint16_t regardless of the tensor dtype chosen above.
#if HAS_NCCL_BF16_DATATYPE
  auto* raw_output =
      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::BFloat16>());
#else
  auto* raw_output =
      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>());
#endif

  _float_to_bfloat16_cuda_kernel<<<
      gridDim,
      blockDim,
      0,
      at::cuda::getCurrentCUDAStream()>>>(
      input.const_data_ptr<float>(), nrows, ncols, raw_output);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return output;
}
105 
// Dequantizes a 2D bfloat16 CUDA tensor back to float.
//
// Mirrors _float_to_bfloat16_cuda: the input tensor's dtype is BFloat16 when
// NCCL supports it natively, otherwise Half is used as a 16-bit carrier; in
// both cases the stored bits are interpreted as bfloat16.
at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) {
  TENSOR_ON_CUDA_GPU(input);
  // Only 2D inputs are supported at present.
  TENSOR_NDIM_EQUALS(input, 2);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(input.get_device());

  const auto nrows = input.size(0);
  const auto ncols = input.size(1);
  const size_t output_columns = ncols;

  auto output = at::empty({nrows, ncols}, input.options().dtype(at::kFloat));

  // Empty tensor: skip the launch (a zero-sized grid would be invalid).
  if (nrows == 0 || ncols == 0) {
    return output;
  }

  // 256 threads per block; x spans columns first, remaining threads stack
  // along y over rows.
  constexpr size_t threads_per_block = 256;
  const auto block_x = std::min(output_columns, threads_per_block);
  dim3 blockDim(block_x, threads_per_block / block_x);
  dim3 gridDim(
      (output_columns + blockDim.x - 1) / blockDim.x,
      std::min<size_t>((nrows + blockDim.y - 1) / blockDim.y, 65535u));

  // The kernel consumes raw 16-bit patterns, so expose the input buffer as
  // const uint16_t regardless of its tensor dtype.
#if HAS_NCCL_BF16_DATATYPE
  const auto* raw_input =
      reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::BFloat16>());
#else
  const auto* raw_input =
      reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::Half>());
#endif

  _bfloat16_to_float_cuda_kernel<<<
      gridDim,
      blockDim,
      0,
      at::cuda::getCurrentCUDAStream()>>>(
      raw_input, nrows, ncols, output.mutable_data_ptr<float>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return output;
}
152 
// Helper: register `function` as the CUDA-backend implementation of the
// operator named `name` on the dispatcher registry `m` in scope.
#define DISPATCH_TO_CUDA(name, function) \
  m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function)))

// Register the CUDA implementations for the "quantization" op namespace so
// the dispatcher routes CUDA tensors to the functions above. (The operator
// schemas are presumably declared elsewhere via TORCH_LIBRARY — not visible
// in this file.)
TORCH_LIBRARY_IMPL(quantization, CUDA, m) {
  DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda);
  DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda);
}
160 
161 } // namespace torch::distributed::c10d::quantization
162