#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/layer_norm.h>

#include <type_traits>

#include <thrust/tuple.h>

#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/thread_constants.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/native_layer_norm_native.h>
#include <ATen/ops/native_layer_norm_backward_native.h>
#include <ATen/ops/zeros_like_native.h>
#endif

#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/env.h>


namespace at::native {

namespace {

constexpr int kCUDANumThreads = 256;
constexpr unsigned int kWarpSize = C10_WARP_SIZE;
constexpr int vec_size = 4; // we could make this dtype-dependent, but that would lead to different results between float and low-precision types

// Aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh)
template<typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

// Checks alignment of buffers for using vectorized loads / stores
template<typename T>
bool can_vectorize(const T* ptr, int alignment) {
  uint64_t addr = reinterpret_cast<uint64_t>(ptr);
  return addr % alignment == 0;
}


template <typename T, typename T_ACC>
__global__ void RowwiseMomentsCUDAKernel(
    int64_t N,
    T_ACC eps,
    const T* X,
    T_ACC* mean,
    T_ACC* rstd) {
  using WelfordType = WelfordData<T_ACC, int64_t>;
  using WelfordOp =
      WelfordOps<T_ACC, T_ACC, int64_t, thrust::pair<T_ACC, T_ACC>>;

  __shared__
      typename std::aligned_storage<sizeof(WelfordType), alignof(WelfordType)>::
          type val_shared[C10_WARP_SIZE];
  WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);

  const int64_t i = blockIdx.x;
  WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
  WelfordType val(0, 0, 0, 0);

  for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
    const int64_t index = i * N + j;
    val = welford_op.reduce(val, static_cast<T_ACC>(X[index]), index);
  }
  val = cuda_utils::BlockReduce(
      val,
      welford_op,
      /*identity_element=*/WelfordType(0, 0, 0, 0),
      val_shared_ptr);

  if (threadIdx.x == 0) {
    T_ACC m1;
    T_ACC m2;
    thrust::tie(m2, m1) = welford_op.project(val);
    mean[i] = m1;
    rstd[i] = c10::cuda::compat::rsqrt(m2 + eps);
  }
}

template <typename T, typename T_ACC>
__global__ void LayerNormForwardCUDAKernel(
    int64_t N,
    const T* X,
    const T_ACC* mean,
    const T_ACC* rstd,
    const T* gamma,
    const T* beta,
    T* Y) {
  const int64_t i = blockIdx.x;
  for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
    const int64_t index = i * N + j;
    const T_ACC gamma_v =
        gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]);
    const T_ACC beta_v =
        beta == nullptr ? T_ACC(0) : static_cast<T_ACC>(beta[j]);
    Y[index] = (static_cast<T_ACC>(X[index]) - static_cast<T_ACC>(mean[i])) *
            static_cast<T_ACC>(rstd[i]) * gamma_v +
        beta_v;
  }
}

struct WelfordDataLN {
  float mean;
  float sigma2;
  float count;
  C10_HOST_DEVICE WelfordDataLN() : mean(0.f), sigma2(0.f), count(0.f) {}
  C10_HOST_DEVICE WelfordDataLN(float mean, float sigma2, float count) : mean(mean), sigma2(sigma2), count(count) {}
};

template<typename U> __device__
WelfordDataLN cuWelfordOnlineSum(
    const U val,
    const WelfordDataLN& curr_sum) {
  U delta = val - curr_sum.mean;
  U new_count = curr_sum.count + 1.f;
  U new_mean = curr_sum.mean + delta * (1.f / new_count); // proper division is slow, this is less accurate but noticeably faster
  return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
}

__device__
WelfordDataLN cuWelfordCombine(
    const WelfordDataLN dataB,
    const WelfordDataLN dataA) {
  using U = decltype(dataB.count);
  U delta = dataB.mean - dataA.mean;
  U count = dataA.count + dataB.count;
  U mean, sigma2;
  if (count > decltype(dataB.count){0}) {
    auto coef = 1.f / count; // NB: we don't compile with --use_fast_math, but this emulates it: 1.f/count maps to the reciprocal intrinsic, and `* coef` is a multiplication instead of a slow fp division
    auto nA = dataA.count * coef;
    auto nB = dataB.count * coef;
    mean = nA * dataA.mean + nB * dataB.mean;
    sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB;
  } else {
    mean = U(0);
    sigma2 = U(0);
  }
  return {mean, sigma2, count};
}
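
// Reference for the combination step above: merging two Welford partials
// (mean_A, M2_A, n_A) and (mean_B, M2_B, n_B) follows the standard parallel
// (Chan et al.) update, which is what cuWelfordCombine computes with
// coef = 1 / (n_A + n_B):
//   delta = mean_B - mean_A
//   mean  = (n_A * mean_A + n_B * mean_B) / (n_A + n_B)
//   M2    = M2_A + M2_B + delta^2 * n_A * n_B / (n_A + n_B)
// Here sigma2 plays the role of M2 (the running sum of squared deviations);
// it is only divided by N at the end of compute_stats.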

template<typename T>
__device__ WelfordDataLN compute_stats(
    const T* __restrict__ X,
    const int N,
    float* buf) {
  // X points to the row to read
  using vec_t = aligned_vector<T, vec_size>;
  using acc_t = acc_type<T, true>;
  const vec_t* X_vec = reinterpret_cast<const vec_t*>(X);
  const int numx = blockDim.x * blockDim.y;
  const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
  const int n_vec_to_read = N / vec_size;
  WelfordDataLN wd(0.f, 0.f, 0.f);
  // no tail, we check that N is a multiple of vec_size
  for (int i = thrx; i < n_vec_to_read; i += numx) {
    vec_t data = X_vec[i];
#pragma unroll
    for (int ii = 0; ii < vec_size; ii++) {
      wd = cuWelfordOnlineSum(static_cast<acc_t>(data.val[ii]), wd);
    }
  }
  // intra-warp reduction
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset),
                      WARP_SHFL_DOWN(wd.sigma2, offset),
                      WARP_SHFL_DOWN(wd.count, offset)};
    wd = cuWelfordCombine(wd, wdB);
  }
  // threadIdx.x == 0 has correct values for each warp
  // inter-warp reductions
  if (blockDim.y > 1) {
    float* meansigmabuf = buf;
    float* countbuf = buf + blockDim.y;
    for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
      // upper half of warps write to shared
      if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) {
        const int wrt_y = threadIdx.y - offset;
        meansigmabuf[2 * wrt_y] = wd.mean;
        meansigmabuf[2 * wrt_y + 1] = wd.sigma2;
        countbuf[wrt_y] = wd.count;
      }
      __syncthreads();
      // lower half merges
      if (threadIdx.x == 0 && threadIdx.y < offset) {
        WelfordDataLN wdB{meansigmabuf[2 * threadIdx.y],
                          meansigmabuf[2 * threadIdx.y + 1],
                          countbuf[threadIdx.y]};
        wd = cuWelfordCombine(wd, wdB);
      }
      __syncthreads();
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      meansigmabuf[0] = wd.mean;
      meansigmabuf[1] = wd.sigma2 / float(N);
    }
    __syncthreads();
    return WelfordDataLN{meansigmabuf[0], meansigmabuf[1], 0.f};
  } else {
    return WelfordDataLN{WARP_SHFL(wd.mean, 0), WARP_SHFL(wd.sigma2, 0) / float(N), 0.f};
  }
}
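
// Note on compute_stats' return value: the WelfordDataLN it hands back is not a
// raw Welford partial anymore. After the block-wide reduction, sigma2 has been
// divided by N, so the caller receives {row mean, biased row variance, 0.f};
// the count field is not meaningful past this point and is ignored by callers.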


template <typename T, typename T_ACC,
          typename std::enable_if<!std::is_same<T, double>::value, int>::type = 0>
__device__ __inline__ void vectorized_layer_norm_kernel_impl(
    const int N,
    T_ACC eps,
    const T* __restrict__ X,
    const T* gamma,
    const T* beta,
    T_ACC* mean,
    T_ACC* rstd,
    T* Y) {
  extern __shared__ float s_data[]; // if we made smem WelfordDataLN type, there would be bank conflicts,
                                    // as one thread would have to write 3 consecutive floats
  auto i1 = blockIdx.x;
  const T* block_row = X + i1 * N;
  WelfordDataLN wd = compute_stats(block_row, N, s_data);

  using vec_t = aligned_vector<T, vec_size>;
  const vec_t* X_vec = reinterpret_cast<const vec_t*>(block_row);
  const vec_t* gamma_vec = (gamma != nullptr) ? reinterpret_cast<const vec_t*>(gamma) : nullptr;
  const vec_t* beta_vec = (beta != nullptr) ? reinterpret_cast<const vec_t*>(beta) : nullptr;
  vec_t* Y_vec = reinterpret_cast<vec_t*>(Y + i1 * N);

  const int numx = blockDim.x * blockDim.y;
  const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
  const int n_vec_to_read = N / vec_size;

  T_ACC rstd_val = c10::cuda::compat::rsqrt(wd.sigma2 + eps);

  // No tail, N is guaranteed to be a multiple of vec_size
  for (int i = thrx; i < n_vec_to_read; i += numx) {
    vec_t data = X_vec[i];
    vec_t out;

    // Computation is performed in T_ACC: X is cast to T_ACC and the result is implicitly cast back to T
    if (gamma_vec != nullptr && beta_vec != nullptr) {
#pragma unroll
      for (int ii = 0; ii < vec_size; ii++) {
        out.val[ii] = static_cast<T_ACC>(gamma_vec[i].val[ii]) * (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean))
            + static_cast<T_ACC>(beta_vec[i].val[ii]);
      }
    } else if (gamma_vec != nullptr) {
#pragma unroll
      for (int ii = 0; ii < vec_size; ii++) {
        out.val[ii] = static_cast<T_ACC>(gamma_vec[i].val[ii]) * (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean));
      }
    } else if (beta_vec != nullptr) {
#pragma unroll
      for (int ii = 0; ii < vec_size; ii++) {
        out.val[ii] = (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean)) + static_cast<T_ACC>(beta_vec[i].val[ii]);
      }
    } else {
#pragma unroll
      for (int ii = 0; ii < vec_size; ii++) {
        out.val[ii] = rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean);
      }
    }
    Y_vec[i] = out;
  }
  if (thrx == 0) {
    mean[i1] = wd.mean;
    rstd[i1] = rstd_val;
  }
}

template <typename T, typename T_ACC,
          typename std::enable_if<std::is_same<T, double>::value, int>::type = 0>
__device__ __inline__ void vectorized_layer_norm_kernel_impl(
    const int /*N*/,
    T_ACC /*eps*/,
    const T* __restrict__ /*X*/,
    const T* /*gamma*/,
    const T* /*beta*/,
    T_ACC* /*mean*/,
    T_ACC* /*rstd*/,
    T* /*Y*/) {
  CUDA_KERNEL_ASSERT(false && "doesn't work with double");
}

// to avoid windows SFINAE errors
template <typename T, typename T_ACC>
__global__ void vectorized_layer_norm_kernel(
    const int N,
    T_ACC eps,
    const T* __restrict__ X,
    const T* gamma,
    const T* beta,
    T_ACC* mean,
    T_ACC* rstd,
    T* Y) {
  vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y);
}


template<typename T, typename T_ACC>
__device__ __inline__ void compute_gI(
    const T* __restrict__ dY,
    const T* __restrict__ X,
    const T_ACC* __restrict__ mean,
    const T_ACC* __restrict__ rstd,
    const T* __restrict__ gamma,
    T* dX,
    const int N,
    T_ACC* buf) {
  const auto i1 = blockIdx.x;
  const T_ACC mean_val = mean[i1];
  const T_ACC rstd_val = rstd[i1];
  T_ACC stats_x1{0}, stats_x2{0};
  constexpr int unroll = 4;
  auto l = unroll * threadIdx.x;
  const T* X_i = X + i1 * N;
  const T* dY_i = dY + i1 * N;
  T* dX_i = dX + i1 * N;
  // vectorized reads don't improve perf, so use regular unrolling

  for (; l + unroll - 1 < N; l += blockDim.x * unroll) {
#pragma unroll
    for (int k = 0; k < unroll; k++) {
      const auto gamma_val = (gamma != nullptr) ? static_cast<T_ACC>(gamma[l + k]) : T_ACC(1);
      const auto c_h = static_cast<T_ACC>(X_i[l + k]);
      const auto c_loss = static_cast<T_ACC>(dY_i[l + k]);
      stats_x1 += c_loss * gamma_val;
      stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
    }
  }
  for (; l < N; l++) {
    const auto gamma_val = (gamma != nullptr) ? static_cast<T_ACC>(gamma[l]) : T_ACC(1);
    const auto c_h = static_cast<T_ACC>(X_i[l]);
    const auto c_loss = static_cast<T_ACC>(dY_i[l]);
    stats_x1 += c_loss * gamma_val;
    stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
  }

  stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf);
  stats_x2 = cuda_utils::BlockReduceSum(stats_x2, buf);
  if (threadIdx.x == 0) {
    buf[0] = stats_x1;
    buf[1] = stats_x2;
  }
  __syncthreads();
  stats_x1 = buf[0];
  stats_x2 = buf[1];
  T_ACC fH = N;
  T_ACC term1 = (T_ACC(1) / fH) * rstd_val;

  for (int l = threadIdx.x; l < N; l += blockDim.x) {
    const auto x = X_i[l];
    const auto dy = dY_i[l];
    const auto gamma_val = (gamma != nullptr) ? static_cast<T_ACC>(gamma[l]) : T_ACC(1);

    T_ACC f_grad_input = fH * gamma_val * dy;
    f_grad_input -= (x - mean_val) * rstd_val * stats_x2;
    f_grad_input -= stats_x1;
    f_grad_input *= term1;
    dX_i[l] = f_grad_input;
  }
}
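
// Math note for compute_gI (and the vectorized variant below): with
// x_hat = (x - mean) * rstd, the layer norm input gradient for one row is
//   dX = rstd / N * ( N * gamma * dY
//                     - sum_j(gamma_j * dY_j)
//                     - x_hat * sum_j(gamma_j * dY_j * x_hat_j) )
// stats_x1 accumulates sum_j(gamma_j * dY_j), stats_x2 accumulates
// sum_j(gamma_j * dY_j * x_hat_j), and term1 is rstd / N.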


template<typename T, typename T_ACC>
__global__ void layer_norm_grad_input_kernel(
    const T* __restrict__ dY,
    const T* __restrict__ X,
    const T_ACC* __restrict__ mean,
    const T_ACC* __restrict__ rstd,
    const T* __restrict__ gamma,
    T* dX,
    const int N) {
  alignas(sizeof(double)) extern __shared__ char s_data1[];
  T_ACC* buf = reinterpret_cast<T_ACC*>(&s_data1);

  compute_gI(dY, X, mean, rstd, gamma, dX, N, buf);
}


// This implementation gets called when input buffers (dY, X, gamma and dX) are aligned
// to vec_size * sizeof(T). Compared to the unvectorized implementation, it is about 10%
// faster measured at PT operator level, with some cases seeing a 2X speedup (where N >> M).
// There are no noticeable regressions on the rest of the sizes.

template<typename T, typename T_ACC>
__global__ void layer_norm_grad_input_kernel_vectorized(
    const T* __restrict__ dY,
    const T* __restrict__ X,
    const T_ACC* __restrict__ mean,
    const T_ACC* __restrict__ rstd,
    const T* __restrict__ gamma,
    T* dX,
    const int N) {
  alignas(sizeof(double)) extern __shared__ char shared_data[];
  T_ACC* reduce_buf = reinterpret_cast<T_ACC*>(&shared_data);

  const auto bIdx = blockIdx.x;
  const T_ACC mean_val = mean[bIdx];
  const T_ACC rstd_val = rstd[bIdx];
  const T* X_i = X + bIdx * N;
  const T* dY_i = dY + bIdx * N;
  T* dX_i = dX + bIdx * N;

  using vec_t = aligned_vector<T, vec_size>;
  const vec_t* const X_i_vec_ptr = reinterpret_cast<const vec_t*>(X_i);
  const vec_t* const dY_i_vec_ptr = reinterpret_cast<const vec_t*>(dY_i);
  const vec_t* const gamma_vec_ptr = (gamma != nullptr) ? reinterpret_cast<const vec_t*>(gamma) : nullptr;
  vec_t* const dX_i_vec = reinterpret_cast<vec_t*>(dX_i);

  vec_t X_i_vec_reg, dY_i_vec_reg, gamma_vec_reg, dX_i_vec_reg;
  for (int k = 0; k < vec_size; ++k) {
    gamma_vec_reg.val[k] = T(1);
  }

  T_ACC stats_x1{0}, stats_x2{0};
  unsigned int l = threadIdx.x * vec_size;
  for (; l + vec_size - 1 < N; l += blockDim.x * vec_size) {
    unsigned int vec_idx = l / vec_size;
    if (gamma != nullptr) {
      gamma_vec_reg = gamma_vec_ptr[vec_idx];
    }

    X_i_vec_reg = X_i_vec_ptr[vec_idx];
    dY_i_vec_reg = dY_i_vec_ptr[vec_idx];

    for (int k = 0; k < vec_size; ++k) {
      const auto gamma_val = static_cast<T_ACC>(gamma_vec_reg.val[k]);
      const auto c_h = static_cast<T_ACC>(X_i_vec_reg.val[k]);
      const auto c_loss = static_cast<T_ACC>(dY_i_vec_reg.val[k]);
      stats_x1 += c_loss * gamma_val;
      stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
    }
  }

  // Tail loop
  for (; l < N; l++) {
    const auto gamma_val = (gamma != nullptr) ? static_cast<T_ACC>(gamma[l]) : T_ACC(1);
    const auto c_h = static_cast<T_ACC>(X_i[l]);
    const auto c_loss = static_cast<T_ACC>(dY_i[l]);
    stats_x1 += c_loss * gamma_val;
    stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
  }

  // Reduction in shared memory
  stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf);
  stats_x2 = cuda_utils::BlockReduceSum(stats_x2, reduce_buf);
  if (threadIdx.x == 0) {
    reduce_buf[0] = stats_x1;
    reduce_buf[1] = stats_x2;
  }
  __syncthreads();
  stats_x1 = reduce_buf[0];
  stats_x2 = reduce_buf[1];

  T_ACC fH = N;
  T_ACC term1 = (T_ACC(1) / fH) * rstd_val;

  l = threadIdx.x * vec_size;
  for (; l + vec_size - 1 < N; l += blockDim.x * vec_size) {
    unsigned int vec_idx = l / vec_size;
    if (gamma != nullptr) {
      gamma_vec_reg = gamma_vec_ptr[vec_idx];
    }

    X_i_vec_reg = X_i_vec_ptr[vec_idx];
    dY_i_vec_reg = dY_i_vec_ptr[vec_idx];

    for (int k = 0; k < vec_size; ++k) {
      const auto gamma_val = static_cast<T_ACC>(gamma_vec_reg.val[k]);
      const auto x = static_cast<T_ACC>(X_i_vec_reg.val[k]);
      const auto dy = static_cast<T_ACC>(dY_i_vec_reg.val[k]);

      T_ACC f_grad_input = fH * gamma_val * dy;
      f_grad_input -= (x - mean_val) * rstd_val * stats_x2;
      f_grad_input -= stats_x1;
      f_grad_input *= term1;
      dX_i_vec_reg.val[k] = f_grad_input;
    }

    dX_i_vec[vec_idx] = dX_i_vec_reg;
  }

  // Tail loop
  for (; l < N; l += blockDim.x) {
    const auto x = X_i[l];
    const auto dy = dY_i[l];
    const auto gamma_val = (gamma != nullptr) ? static_cast<T_ACC>(gamma[l]) : T_ACC(1);

    T_ACC f_grad_input = fH * gamma_val * dy;
    f_grad_input -= (x - mean_val) * rstd_val * stats_x2;
    f_grad_input -= stats_x1;
    f_grad_input *= term1;
    dX_i[l] = f_grad_input;
  }
}


template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardSimpleCUDAKernel(
    int64_t M,
    int64_t N,
    const T* dY,
    const T* X,
    const T_ACC* mean,
    const T_ACC* rstd,
    T* dg,
    T* db) {
  const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
  if (j < N) {
    T_ACC sum1 = 0;
    T_ACC sum2 = 0;
    for (int64_t i = 0; i < M; ++i) {
      const int64_t index = i * N + j;
      sum1 += dg == nullptr ? T_ACC(0)
                            : static_cast<T_ACC>(dY[index]) *
              (static_cast<T_ACC>(X[index]) - static_cast<T_ACC>(mean[i])) *
              static_cast<T_ACC>(rstd[i]);
      sum2 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index]);
    }
    if (dg != nullptr) {
      dg[j] = sum1;
    }
    if (db != nullptr) {
      db[j] = sum2;
    }
  }
}
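
// All of the gamma/beta backward kernels below compute the same column-wise sums,
// just with different parallelization strategies:
//   dgamma[j] = sum_i dY[i][j] * (X[i][j] - mean[i]) * rstd[i]
//   dbeta[j]  = sum_i dY[i][j]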

// This implementation gets called if M and N divide by 32. This case should
// be the most common. We can then make better use of warp level intrinsics
// to improve performance.

template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardCUDAKernel_32x32(
    int64_t M,
    int64_t N,
    const T* dY,
    const T* X,
    const T_ACC* mean,
    const T_ACC* rstd,
    T* dg,
    T* db) {
  alignas(sizeof(double)) extern __shared__ char s_data1[];
  T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
  T_ACC* s_dg;
  T_ACC* s_db;

  T_ACC dg_sum = 0;
  T_ACC db_sum = 0;

  const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;

  if (j < N) {
    constexpr int unroll_factor = 8;
    int laneId = threadIdx.x & (C10_WARP_SIZE - 1);

    T_ACC mean_reg, mean_reg_tmp;
    T_ACC rstd_reg, rstd_reg_tmp;
    T dY_reg;
    T X_reg;

    // Main loop
    int bcounter;
    for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor);
         bcounter++) {
      int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;

      if (laneId < unroll_factor) {
        mean_reg_tmp = mean[offset + laneId];
        rstd_reg_tmp = rstd[offset + laneId];
      }
      WARP_SYNC();

#pragma unroll
      for (int ii = 0; ii < unroll_factor; ++ii) {
        dY_reg = dY[(offset + ii) * N + j];
        X_reg = X[(offset + ii) * N + j];
        mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize);
        rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize);
        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
        db_sum += dY_reg;
      }
    }

    // Remainder loop
    int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
    for (int ii = 0; ii < unroll_factor; ii++) {
      if ((offset + ii) < M) {
        mean_reg = mean[offset + ii];
        rstd_reg = rstd[offset + ii];
        dY_reg = dY[(offset + ii) * N + j];
        X_reg = X[(offset + ii) * N + j];
        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
        db_sum += dY_reg;
      }
    }

    // This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and
    // gets called when M and N divide by 32. We can use warp shuffles
    // for the final reduction step. This removes 4 shmem loads and
    // stores with their corresponding __syncthreads()

    // This greatly reduces bank conflicts at the expense of a little
    // extra shared memory. It does not impact occupancy
    int padded_bx = (1 + blockDim.x);

    s_dg = s_data_typed;
    s_db = s_data_typed + (padded_bx * blockDim.y);
    s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum;
    s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum;
    __syncthreads();

    // Load transposed so that a warp holds an entire column
    T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y];
    T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y];
    for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) {
      reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
      reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
    }

    if (threadIdx.x == 0) {
      const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
      if (dg) {
        dg[j] = reg_dg;
      }
      if (db) {
        db[j] = reg_db;
      }
    }
  }
}

template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardCUDAKernel(
    int64_t M,
    int64_t N,
    const T* dY,
    const T* X,
    const T_ACC* mean,
    const T_ACC* rstd,
    T* dg,
    T* db) {
  alignas(sizeof(double)) extern __shared__ char s_data1[];
  T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
  T_ACC* s_dg;
  T_ACC* s_db;

  const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;

  T_ACC dg_sum = 0;
  T_ACC db_sum = 0;

  if (j < N) {
    constexpr int unroll_factor = 8;

    T_ACC mean_reg;
    T_ACC rstd_reg;
    T dY_reg;
    T X_reg;

    // Main loop
    int bcounter;
    for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++) {
      int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;

#pragma unroll
      for (int ii = 0; ii < unroll_factor; ++ii) {
        dY_reg = dY[(offset + ii) * N + j];
        X_reg = X[(offset + ii) * N + j];
        mean_reg = mean[offset + ii];
        rstd_reg = rstd[offset + ii];
        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
        db_sum += dY_reg;
      }
    }

    // Remainder loop
    int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
    for (int ii = 0; ii < unroll_factor; ii++) {
      if ((offset + ii) < M) {
        dY_reg = dY[(offset + ii) * N + j];
        X_reg = X[(offset + ii) * N + j];
        mean_reg = mean[offset + ii];
        rstd_reg = rstd[offset + ii];
        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
        db_sum += dY_reg;
      }
    }

    // Do the final reduction in shared memory
    s_dg = s_data_typed;
    s_db = s_data_typed + blockDim.x * blockDim.y;
    s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
    s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
    __syncthreads();

    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
      if (threadIdx.y < offset) {
        s_dg[threadIdx.y * blockDim.x + threadIdx.x] +=
            s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
        s_db[threadIdx.y * blockDim.x + threadIdx.x] +=
            s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
      }
      __syncthreads();
    }

    if (threadIdx.y == 0) {
      if (dg) {
        dg[j] = s_dg[threadIdx.x];
      }
      if (db) {
        db[j] = s_db[threadIdx.x];
      }
    }
  }
}

template <typename T, typename T_ACC>
void launch_vectorized_layer_norm_kernel(
    int N,
    int64_t M,
    T_ACC eps,
    const T* X_data,
    const T* gamma_data,
    const T* beta_data,
    T* Y_data,
    T_ACC* mean_data,
    T_ACC* rstd_data) {
  // constexpr int alignment = 16; // currently unused to make sure float and half results are bitwise accurate
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int warp_size = at::cuda::warp_size();
  const dim3 threads(warp_size, num_threads() / warp_size, 1);
  const dim3 blocks(M);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(threads.y % 2 == 0 || threads.y == 1);
  int nshared = threads.y > 1 ? threads.y * 3 / 2 * sizeof(T_ACC) : 0;
  vectorized_layer_norm_kernel<<<blocks, threads, nshared, stream>>>(N, eps, X_data,
      gamma_data, beta_data, mean_data, rstd_data, Y_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
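
// Launch-configuration note for the vectorized forward path: one block handles one
// row, with blockDim = (warp_size, num_threads() / warp_size). The dynamic shared
// memory size (threads.y * 3/2 * sizeof(T_ACC)) matches what compute_stats needs for
// its inter-warp reduction: at most threads.y / 2 warps write at any step, each
// storing 2 floats (mean, sigma2) into meansigmabuf and 1 float (count) into countbuf.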

template <typename T, typename T_ACC>
void LayerNormKernelImplInternal(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    T_ACC eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  // assumes input, gamma and beta are of proper shape, this was checked in _check_layer_norm_inputs
  // assumes all tensors are contiguous
  TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \
file a support request to support bigger batches");
  const T* X_data = X.const_data_ptr<T>();
  const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
  const T* beta_data = beta.defined() ? beta.const_data_ptr<T>() : nullptr;
  T* Y_data = Y->data_ptr<T>();
  T_ACC* mean_data = mean->data_ptr<T_ACC>();
  T_ACC* rstd_data = rstd->data_ptr<T_ACC>();

  // check if we can take the fast path - all tensors are properly aligned, N is at most 2^24
  // (so the float count used in Welford stays exact), and N is a multiple of vec_size
  // (so that all rows are aligned if the tensor is aligned)
  constexpr int num_vec_elems = vec_size;
  constexpr int alignment = num_vec_elems * sizeof(T);
  bool can_vec_X = can_vectorize(X_data, alignment);
  bool can_vec_Y = can_vectorize(Y_data, alignment);
  bool can_vec_gamma = gamma.defined() ? can_vectorize(gamma_data, alignment) : true;
  bool can_vec_beta = beta.defined() ? can_vectorize(beta_data, alignment) : true;

  if ((std::is_same<T, float>::value || std::is_same<T, at::Half>::value || std::is_same<T, at::BFloat16>::value) &&
      N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) && N % num_vec_elems == 0 &&
      can_vec_X && can_vec_Y && can_vec_gamma && can_vec_beta) {
    launch_vectorized_layer_norm_kernel(static_cast<int>(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data);
  } else {
    cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
    RowwiseMomentsCUDAKernel<T, T_ACC>
        <<<M, cuda_utils::kCUDABlockReduceNumThreads, 0, cuda_stream>>>(
            N, eps, X_data, mean_data, rstd_data);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    LayerNormForwardCUDAKernel<T, T_ACC><<<M, kCUDANumThreads, 0, cuda_stream>>>(
        N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

void LayerNormKernelImpl(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    double eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      X.scalar_type(),
      "LayerNormKernelImpl",
      [&]() {
        using acc_t = acc_type<scalar_t, true>;
        LayerNormKernelImplInternal<scalar_t, acc_t>(
            X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
      });
}

template<typename T, typename T_ACC> __device__
void cuLoadWriteStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    T_ACC* warp_buf1,
    T_ACC* warp_buf2,
    const T* input,
    const T* dout,
    const int i1_end,
    const int64_t N,
    const T_ACC* __restrict__ mean,
    const T_ACC* __restrict__ rstd) {
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    // keep the per-row statistics in the accumulation type, matching cuLoadAddStridedInputs below
    T_ACC curr_mean = mean[i1];
    T_ACC curr_rstd = rstd[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1 * N + i2;
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      if (i2 < N) {
        T_ACC curr_input = static_cast<T_ACC>(input[load_idx]);
        T_ACC curr_dout = static_cast<T_ACC>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_rstd;
      } else {
        warp_buf1[write_idx] = T_ACC(0);
        warp_buf2[write_idx] = T_ACC(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      warp_buf1[write_idx] = T_ACC(0);
      warp_buf2[write_idx] = T_ACC(0);
    }
  }
}

template<typename T, typename T_ACC> __device__
void cuLoadAddStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    T_ACC* warp_buf1,
    T_ACC* warp_buf2,
    const T* input,
    const T* dout,
    const int i1_end,
    const int64_t N,
    const T_ACC* __restrict__ mean,
    const T_ACC* __restrict__ rstd) {
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    T_ACC curr_mean = mean[i1];
    T_ACC curr_rstd = rstd[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1 * N + i2;
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      if (i2 < N) {
        T_ACC curr_input = static_cast<T_ACC>(input[load_idx]);
        T_ACC curr_dout = static_cast<T_ACC>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_rstd;
      }
    }
  }
}

template<typename T, typename T_ACC> __global__
void cuComputePartGradGammaBeta(
    const T* __restrict__ dout,
    const T* __restrict__ input,
    const int64_t M,
    const int64_t N,
    const T_ACC* __restrict__ mean,
    const T_ACC* __restrict__ rstd,
    T_ACC* part_grad_gamma,
    T_ACC* part_grad_beta) {
  const int numsegs_M = (M + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
  const int segs_per_block = (numsegs_M + gridDim.y - 1) / gridDim.y;
  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
  const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
  const int i1_end = i1_beg_plus_one < M ? i1_beg_plus_one : M;
  const int row_stride = blockDim.x + 1;
  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
  const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
  alignas(sizeof(double)) extern __shared__ char shared[];
  T_ACC* buf = reinterpret_cast<T_ACC*>(&shared); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1) * (blockDim.x / blockDim.y) elements
  T_ACC* warp_buf1 = (T_ACC*)buf;
  T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
  // compute partial sums from strided inputs
  // do this to increase number of loads in flight
  cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, N, mean, rstd);
  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) {
    cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, N, mean, rstd);
  }
  __syncthreads();
  // inter-warp reductions
  // sum within each warp
  T_ACC acc1 = T_ACC(0);
  T_ACC acc2 = T_ACC(0);
  for (int k = 0; k < blockDim.y; ++k) {
    int row1 = threadIdx.y + k * blockDim.y;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < N) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * N + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * N + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}
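
// The kernels above and below implement dgamma/dbeta as a two-stage reduction over
// the M rows (this path is only used on ROCm in the launcher further down):
// cuComputePartGradGammaBeta writes partial column sums into part_grad_gamma /
// part_grad_beta of shape (part_size, N), and cuComputeGradGammaBeta then reduces
// those part_size partial rows down to the final grad_gamma / grad_beta of length N.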

template<typename T, typename T_ACC> __global__
void cuComputeGradGammaBeta(
    const T_ACC* part_grad_gamma,
    const T_ACC* part_grad_beta,
    const int part_size,
    const int64_t M,
    const int64_t N,
    T* grad_gamma,
    T* grad_beta) {
  // sum partial gradients for gamma and beta
  alignas(sizeof(double)) extern __shared__ char shared[];
  T_ACC* buf = reinterpret_cast<T_ACC*>(&shared);
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;

  // each warp does sequential reductions until the remaining part_size equals the number of warps
  int num_warp_reductions = part_size / blockDim.y;
  T_ACC sum_gamma = T_ACC(0);
  T_ACC sum_beta = T_ACC(0);
  const T_ACC* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * N + i2;
  const T_ACC* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * N + i2;

  if (i2 < N) {
    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * N];
      sum_beta += part_grad_beta_ptr[warp_offset * N];
    }
  }

  // inter-warp reductions
  const int nbsize3 = blockDim.x * blockDim.y / 2;
  for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
    // top half write to shared memory
    if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
      const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
      buf[write_idx] = sum_gamma;
      buf[write_idx + nbsize3] = sum_beta;
    }
    __syncthreads();
    // bottom half sums
    if (threadIdx.y < offset) {
      const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
      sum_gamma += buf[read_idx];
      sum_beta += buf[read_idx + nbsize3];
    }
    __syncthreads();
  }

  // write out fully summed gradients
  if (threadIdx.y == 0 && i2 < N) {
    if (grad_gamma) {
      grad_gamma[i2] = sum_gamma;
    }
    if (grad_beta) {
      grad_beta[i2] = sum_beta;
    }
  }
}

template<typename T, typename T_ACC> __global__
void cuComputeGradInput(
    const T* __restrict__ dout,
    const T* __restrict__ input,
    const int64_t M,
    const int64_t N,
    const T_ACC* __restrict__ mean,
    const T_ACC* __restrict__ rstd,
    const T* gamma,
    T* grad_input) {
  for (int i1 = blockIdx.y; i1 < M; i1 += gridDim.y) {
    T_ACC sum_loss1 = T_ACC(0);
    T_ACC sum_loss2 = T_ACC(0);
    const T_ACC c_mean = mean[i1];
    const T_ACC c_rstd = rstd[i1];
    const T* k_input = input + i1 * N;
    const T* k_dout = dout + i1 * N;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      // Optimization for ROCm MI100
      for (int l = 0; l < N; l += numx) {
        int idx = l + thrx;
        const T_ACC gamma_idx = static_cast<T_ACC>((idx < N) ? gamma[idx] : T(0));
        const T_ACC c_h = static_cast<T_ACC>((idx < N) ? k_input[idx] : T(0));
        const T_ACC c_loss = static_cast<T_ACC>((idx < N) ? k_dout[idx] : T(0));
        sum_loss1 += c_loss * gamma_idx;
        sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_rstd;
      }
    } else {
      for (int l = 0; l < N; l += numx) {
        int idx = l + thrx;
        const T_ACC c_h = static_cast<T_ACC>((idx < N) ? k_input[idx] : T(0));
        const T_ACC c_loss = static_cast<T_ACC>((idx < N) ? k_dout[idx] : T(0));
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_rstd;
      }
    }
    // intra-warp reductions
    for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      alignas(sizeof(double)) extern __shared__ char shared[];
      T_ACC* buf = reinterpret_cast<T_ACC*>(&shared);
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
    }
    // all threads now have the two sums over l
    T_ACC fH = (T_ACC)N;
    T_ACC term1 = (T_ACC(1) / fH) * c_rstd;
    T* k_grad_input = grad_input + i1 * N;
    if (gamma != NULL) {
      for (int l = thrx; l < N; l += numx) {
        const T_ACC c_h = static_cast<T_ACC>(k_input[l]);
        const T_ACC c_loss = static_cast<T_ACC>(k_dout[l]);
        T_ACC f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < N; l += numx) {
        const T_ACC c_h = static_cast<T_ACC>(k_input[l]);
        const T_ACC c_loss = static_cast<T_ACC>(k_dout[l]);
        T_ACC f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
    // prevent race where buf is written again before reads are done
    __syncthreads();
  }
}
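
// cuComputeGradInput is the grid-strided variant of the dX computation, used on
// ROCm for large M in the launcher below; it evaluates the same formula as
// compute_gI (sum_loss1/sum_loss2 correspond to stats_x1/stats_x2), but lets one
// block process multiple rows via the gridDim.y stride.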

template <typename T>
void LayerNormBackwardKernelImplInternal(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t M,
    int64_t N,
    Tensor* dX,
    Tensor* dgamma,
    Tensor* dbeta) {
  using T_ACC = acc_type<T, true>;
  TORCH_CHECK(dY.numel() == M * N);
  TORCH_CHECK(mean.numel() == M);
  TORCH_CHECK(rstd.numel() == M);
  TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \
file a support request to support bigger batches");
  TORCH_CHECK(N <= std::numeric_limits<int>::max(), "Normalized shape should have less than INT_MAX elements, \
file a support request to support bigger normalized shapes");
  const T* dY_data = dY.template const_data_ptr<T>();
  const T* X_data = X.template const_data_ptr<T>();
  const T_ACC* mean_data = mean.template const_data_ptr<T_ACC>();
  const T_ACC* rstd_data = rstd.template const_data_ptr<T_ACC>();
  const T* gamma_data =
      gamma.defined() ? gamma.template const_data_ptr<T>() : nullptr;
  T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
  cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
  const int warp_size = at::cuda::warp_size();
  if (dX_data != nullptr) {
#ifdef USE_ROCM
    if (M >= 32768) {
      const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
      const dim3 blocks1(1, std::min((uint64_t)M, maxGridY), 1);
      dim3 threads1(warp_size, 4, 1);
      threads1.y = 2; // Optimization for ROCm
      int nshared =
          threads1.y > 1 ?
              threads1.y * threads1.x * sizeof(T_ACC) :
              0;
      cuComputeGradInput<<<blocks1, threads1, nshared, cuda_stream>>>(
          dY_data,
          X_data,
          M, N,
          mean_data,
          rstd_data,
          gamma_data,
          dX_data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      const dim3 blocks(M);
      int nshared = (num_threads() / warp_size) * sizeof(T_ACC);
      layer_norm_grad_input_kernel<<<blocks, num_threads(), nshared, cuda_stream>>>(dY_data,
          X_data, mean_data, rstd_data, gamma_data, dX_data, N);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
#else
    const dim3 blocks(M);
    int nshared = (num_threads() / warp_size) * sizeof(T_ACC);

    bool bVectorSizeMultiple = (N % vec_size == 0);
    bool bTargetDataTypes = (std::is_same<T, float>::value || std::is_same<T, at::Half>::value ||
        std::is_same<T, at::BFloat16>::value);
    const unsigned int alignment = sizeof(T) * vec_size;
    bool bAlignedBuffers = can_vectorize(dY_data, alignment) && can_vectorize(X_data, alignment) &&
        can_vectorize(gamma_data, alignment) && can_vectorize(dX_data, alignment);

    if (bAlignedBuffers && bTargetDataTypes && bVectorSizeMultiple) {
      layer_norm_grad_input_kernel_vectorized<<<blocks, num_threads(), nshared, cuda_stream>>>(dY_data,
          X_data, mean_data, rstd_data, gamma_data, dX_data, N);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      layer_norm_grad_input_kernel<<<blocks, num_threads(), nshared, cuda_stream>>>(dY_data,
          X_data, mean_data, rstd_data, gamma_data, dX_data, N);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
#endif
  }

  if (dgamma->defined() || dbeta->defined()) {
    T* dgamma_data =
        dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
    T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;

    if (M < 128) {
      // For small batch size, do colwise reduce directly.
      const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads;
      GammaBetaBackwardSimpleCUDAKernel<T, T_ACC>
          <<<B, kCUDANumThreads, 0, cuda_stream>>>(
              M,
              N,
              dY_data,
              X_data,
              mean_data,
              rstd_data,
              dgamma_data,
              dbeta_data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
#if defined(USE_ROCM)
      // For larger batch sizes, compute partial column sums first, then reduce them.
      const int part_size = warp_size;
      const dim3 threads2(warp_size, 4, 1);
      const dim3 blocks2((N + threads2.x - 1) / threads2.x, part_size, 1);
      const int nshared2_a = 2 * sizeof(T_ACC) * threads2.y * threads2.y * (threads2.x + 1);
      const int nshared2_b = threads2.x * threads2.y * sizeof(T_ACC);
      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;

      const auto part_grad_dtype = at::toAccumulateType(X.scalar_type(), true);
      Tensor part_grad_gamma = at::empty({part_size, N}, gamma.options().dtype(part_grad_dtype));
      Tensor part_grad_beta = at::native::empty_like(part_grad_gamma);

      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, cuda_stream>>>(
          dY_data,
          X_data,
          M, N,
          mean_data,
          rstd_data,
          part_grad_gamma.template data_ptr<T_ACC>(),
          part_grad_beta.template data_ptr<T_ACC>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();

      const dim3 threads3(warp_size, 8, 1); // Optimization for ROCm
      const dim3 blocks3((N + threads3.x - 1) / threads3.x, 1, 1);
      const int nshared3 = threads3.x * threads3.y * sizeof(T_ACC);

      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, cuda_stream>>>(
          part_grad_gamma.template data_ptr<T_ACC>(),
          part_grad_beta.template data_ptr<T_ACC>(),
          part_size,
          M, N,
          dgamma_data,
          dbeta_data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
      if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) {
        // This implementation relies on warp primitives and requires that M and N divide
        // exactly by the warp size.
        dim3 threads{kWarpSize, kWarpSize};
        int blocks = (N + threads.x - 1) / threads.x;

        // If M and N divide by warp_size, we can use warp shuffles for the final reduction.
        // That requires transposing values in shared memory, so we apply a padding to
        // reduce bank conflicts.

        size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y;
        GammaBetaBackwardCUDAKernel_32x32<T, T_ACC>
            <<<blocks, threads, shmem_sz, cuda_stream>>>(
                M,
                N,
                dY_data,
                X_data,
                mean_data,
                rstd_data,
                dgamma_data,
                dbeta_data);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        dim3 threads{16, 32};
        int blocks = (N + threads.x - 1) / threads.x;
        size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y;
        GammaBetaBackwardCUDAKernel<T, T_ACC>
            <<<blocks, threads, shmem_sz, cuda_stream>>>(
                M,
                N,
                dY_data,
                X_data,
                mean_data,
                rstd_data,
                dgamma_data,
                dbeta_data);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
#endif
    }
  }
}

void LayerNormBackwardKernelImpl(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t M,
    int64_t N,
    Tensor* dX,
    Tensor* dgamma,
    Tensor* dbeta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      X.scalar_type(),
      "LayerNormBackwardKernelImpl",
      [&]() {
        LayerNormBackwardKernelImplInternal<scalar_t>(
            dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
      });
}

} // namespace

std::tuple<Tensor, Tensor, Tensor> layer_norm_cuda(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  auto X = input.expect_contiguous();
  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor Y = at::native::empty_like(
      *X,
      std::nullopt /* dtype */,
      std::nullopt /* layout */,
      std::nullopt /* device */,
      std::nullopt /* pin_memory */,
      LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true);
  Tensor mean = at::empty({M}, X->options().dtype(acc_type));
  Tensor rstd = at::empty({M}, X->options().dtype(acc_type));
  // Calling the kernel for M==0 gives a CUDA error
  // See: https://github.com/pytorch/pytorch/pull/28614
  if (M > 0) {
    LayerNormKernelImpl(*X, *gamma, *beta, M, N, eps, &Y, &mean, &rstd);
  }
  const auto input_shape = input.sizes();
  const size_t axis = input.dim() - normalized_shape.size();

  std::vector<int64_t> stat_shape;
  for (const auto idx : c10::irange(axis)) {
    stat_shape.push_back(input_shape[idx]);
  }
  for (const auto C10_UNUSED idx : c10::irange(axis, input.dim())) {
    stat_shape.push_back(1);
  }

  mean = mean.view(stat_shape);
  rstd = rstd.view(stat_shape);

  return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd));
}
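
// Usage sketch (illustrative only, not part of this file): from C++, this kernel is
// reached through the native_layer_norm op, e.g.
//   auto outs = at::native_layer_norm(input, /*normalized_shape=*/{1024}, weight, bias, /*eps=*/1e-5);
//   Tensor Y = std::get<0>(outs), mean = std::get<1>(outs), rstd = std::get<2>(outs);
// where input is a CUDA tensor whose trailing dimensions match normalized_shape, and
// weight/bias are optional tensors of that shape.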

std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cuda(
    const Tensor& dY,
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& mean,
    const Tensor& rstd,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    std::array<bool, 3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  auto X = input.expect_contiguous();
  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;
  if (grad_input_mask[0]) {
    dX = at::native::empty_like(
        *X,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  if (grad_input_mask[1]) {
    dgamma = M > 0 ? at::native::empty_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         LEGACY_CONTIGUOUS_MEMORY_FORMAT)
                   : at::native::zeros_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  if (grad_input_mask[2]) {
    dbeta = M > 0 ? at::native::empty_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        LEGACY_CONTIGUOUS_MEMORY_FORMAT)
                  : at::native::zeros_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  if (M > 0 && N > 0) {
    LayerNormBackwardKernelImpl(
        dY, *X, mean, rstd, *gamma, M, N, &dX, &dgamma, &dbeta);
  }
  return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
}

REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl);
REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl);

} // namespace at::native