xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/layer_norm_kernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/layer_norm.h>
3 
4 #include <type_traits>
5 
6 #include <thrust/tuple.h>
7 
8 #include <ATen/core/Tensor.h>
9 #include <ATen/AccumulateType.h>
10 #include <ATen/Dispatch.h>
11 #include <ATen/cuda/CUDAContext.h>
12 #include <ATen/cuda/detail/IndexUtils.cuh>
13 #include <ATen/native/cuda/block_reduce.cuh>
14 #include <ATen/native/cuda/thread_constants.h>
15 
16 #ifndef AT_PER_OPERATOR_HEADERS
17 #include <ATen/Functions.h>
18 #include <ATen/NativeFunctions.h>
19 #else
20 #include <ATen/ops/empty.h>
21 #include <ATen/ops/empty_like_native.h>
22 #include <ATen/ops/native_layer_norm_native.h>
23 #include <ATen/ops/native_layer_norm_backward_native.h>
24 #include <ATen/ops/zeros_like_native.h>
25 #endif
26 
27 #include <c10/cuda/CUDAMathCompat.h>
28 #include <c10/util/env.h>
29 
30 
31 namespace at::native {
32 
33 namespace {
34 
35 constexpr int kCUDANumThreads = 256;
36 constexpr unsigned int kWarpSize = C10_WARP_SIZE;
37 constexpr int vec_size = 4; // we could make this dtype-dependent, but that would lead to different results between float and low-precision types
38 
39 // aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh)
40 template<typename scalar_t, int vec_size>
41 struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
42   scalar_t val[vec_size];
43 };
44 
45 // Checks alignment of buffers for using vectorized loads / stores
46 template<typename T>
47 bool can_vectorize(const T * ptr, int alignment) {
48   uint64_t addr = reinterpret_cast<uint64_t>(ptr);
49   return addr % alignment == 0;
50 };
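// Illustrative sketch of how the alignment check above gates the vectorized
// path: for float with vec_size == 4 the required alignment is 16 bytes (one
// float4-sized load per iteration). The function and buffer names below are
// assumptions for illustration only; this is not used by the kernels.
#if 0
int alignment_demo() {
  constexpr int alignment = vec_size * sizeof(float); // 4 * 4 = 16 bytes
  alignas(16) static float buffer[8] = {};
  const float* aligned_ptr   = buffer;     // 16-byte aligned -> vectorizable
  const float* unaligned_ptr = buffer + 1; // offset by 4 bytes -> scalar path
  return can_vectorize(aligned_ptr, alignment) &&
        !can_vectorize(unaligned_ptr, alignment); // expected: 1
}
#endif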
51 
52 
53 template <typename T, typename T_ACC>
54 __global__ void RowwiseMomentsCUDAKernel(
55     int64_t N,
56     T_ACC eps,
57     const T* X,
58     T_ACC* mean,
59     T_ACC* rstd) {
60   using WelfordType = WelfordData<T_ACC, int64_t>;
61   using WelfordOp =
62       WelfordOps<T_ACC, T_ACC, int64_t, thrust::pair<T_ACC, T_ACC>>;
63 
64   __shared__
65       typename std::aligned_storage<sizeof(WelfordType), alignof(WelfordType)>::
66           type val_shared[C10_WARP_SIZE];
67   WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);
68 
69   const int64_t i = blockIdx.x;
70   WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
71   WelfordType val(0, 0, 0, 0);
72 
73   for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
74     const int64_t index = i * N + j;
75     val = welford_op.reduce(val, static_cast<T_ACC>(X[index]), index);
76   }
77   val = cuda_utils::BlockReduce(
78       val,
79       welford_op,
80       /*identity_element=*/WelfordType(0, 0, 0, 0),
81       val_shared_ptr);
82 
83   if (threadIdx.x == 0) {
84     T_ACC m1;
85     T_ACC m2;
86     thrust::tie(m2, m1) = welford_op.project(val);
87     mean[i] = m1;
88     rstd[i] = c10::cuda::compat::rsqrt(m2 + eps);
89   }
90 }
91 
92 template <typename T, typename T_ACC>
93 __global__ void LayerNormForwardCUDAKernel(
94     int64_t N,
95     const T* X,
96     const T_ACC* mean,
97     const T_ACC* rstd,
98     const T* gamma,
99     const T* beta,
100     T* Y) {
101   const int64_t i = blockIdx.x;
102   for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
103     const int64_t index = i * N + j;
104     const T_ACC gamma_v =
105         gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]);
106     const T_ACC beta_v =
107         beta == nullptr ? T_ACC(0) : static_cast<T_ACC>(beta[j]);
108     Y[index] = (static_cast<T_ACC>(X[index]) - static_cast<T_ACC>(mean[i])) *
109             static_cast<T_ACC>(rstd[i]) * gamma_v +
110         beta_v;
111   }
112 }
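// Illustrative host-side sketch of what RowwiseMomentsCUDAKernel followed by
// LayerNormForwardCUDAKernel compute for one row of length N. The helper name
// and double accumulation are assumptions for clarity; requires <cmath> and
// <vector> if actually compiled.
#if 0
void layer_norm_row_reference(const std::vector<float>& x,
                              const std::vector<float>& gamma,
                              const std::vector<float>& beta,
                              double eps,
                              std::vector<float>& y) {
  const size_t n = x.size();
  double mean = 0.0, var = 0.0;
  for (float v : x) mean += v;
  mean /= n;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= n; // biased variance, matching WelfordOps with correction = 0
  const double rstd = 1.0 / std::sqrt(var + eps);
  for (size_t j = 0; j < n; ++j) {
    y[j] = static_cast<float>((x[j] - mean) * rstd * gamma[j] + beta[j]);
  }
}
#endif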
113 
114 struct WelfordDataLN{
115   float mean;
116   float sigma2;
117   float count;
118   C10_HOST_DEVICE WelfordDataLN(): mean(0.f), sigma2(0.f), count(0.f){}
119   C10_HOST_DEVICE WelfordDataLN(float mean, float sigma2, float count): mean(mean), sigma2(sigma2), count(count) {}
120 };
121 
122 template<typename U> __device__
123 WelfordDataLN cuWelfordOnlineSum(
124   const U val,
125   const WelfordDataLN& curr_sum)
126 {
127   U delta = val - curr_sum.mean;
128   U new_count = curr_sum.count + 1.f;
129   U new_mean = curr_sum.mean + delta * (1.f/new_count); // a true fp division is slow; multiplying by the reciprocal is less accurate but noticeably faster
130   return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
131 }
132 
133 __device__
134 WelfordDataLN cuWelfordCombine(
135   const WelfordDataLN dataB,
136   const WelfordDataLN dataA
137 ) {
138   using U = decltype(dataB.count);
139   U delta = dataB.mean - dataA.mean;
140   U count = dataA.count + dataB.count;
141   U mean, sigma2;
142   if (count > decltype(dataB.count){0}) {
143     auto coef = 1.f/count; // NB: we don't build with --use_fast_math, but this emulates it: 1.f/count lowers to the reciprocal intrinsic and `* coef` is a multiplication instead of a slow fp division
144     auto nA = dataA.count * coef;
145     auto nB = dataB.count * coef;
146     mean = nA*dataA.mean + nB*dataB.mean;
147     sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB;
148   } else {
149     mean = U(0);
150     sigma2 = U(0);
151   }
152   return {mean, sigma2, count};
153 }
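// For reference, the two helpers above implement the standard Welford /
// Chan et al. updates (sigma2 is the running sum of squared deviations, so
// the biased variance of a row of length N is sigma2 / N):
//   online step:   delta = x - mean
//                  mean' = mean + delta / (count + 1)
//                  sigma2' = sigma2 + delta * (x - mean')
//   combine A, B:  n = nA + nB
//                  mean = (nA * meanA + nB * meanB) / n
//                  sigma2 = sigma2A + sigma2B + (meanB - meanA)^2 * nA * nB / n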
154 
155 template<typename T>
156 __device__ WelfordDataLN compute_stats(
157   const T*  __restrict__ X,
158   const int N,
159   float * buf
160   ) {
161     //X points to the row to read
162     using vec_t = aligned_vector<T, vec_size>;
163     using acc_t = acc_type<T, true>;
164     const vec_t * X_vec = reinterpret_cast<const vec_t*>(X);
165     const int numx = blockDim.x * blockDim.y;
166     const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
167     const int n_vec_to_read = N/vec_size;
168     WelfordDataLN wd(0.f, 0.f, 0.f);
169     // no tail loop: the caller checks that N is a multiple of vec_size
170     for (int i = thrx; i < n_vec_to_read; i += numx) {
171       vec_t data = X_vec[i];
172       #pragma unroll
173       for (int ii=0; ii < vec_size; ii++){
174         wd = cuWelfordOnlineSum(static_cast<acc_t>(data.val[ii]), wd);
175       }
176     }
177     // intra-warp reduction
178     for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
179         WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset),
180         WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)};
181         wd = cuWelfordCombine(wd, wdB);
182     }
183     // threadIdx.x == 0 has correct values for each warp
184     // inter-warp reductions
185     if (blockDim.y > 1) {
186       float * meansigmabuf = buf;
187       float * countbuf = buf + blockDim.y;
188       for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
189         // upper half of warps write to shared
190         if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
191           const int wrt_y = threadIdx.y - offset;
192           meansigmabuf[2*wrt_y] = wd.mean;
193           meansigmabuf[2*wrt_y+1] = wd.sigma2;
194           countbuf[wrt_y] = wd.count;
195         }
196         __syncthreads();
197         // lower half merges
198         if (threadIdx.x == 0 && threadIdx.y < offset) {
199           WelfordDataLN wdB{meansigmabuf[2*threadIdx.y],
200                           meansigmabuf[2*threadIdx.y+1],
201                           countbuf[threadIdx.y]};
202           wd = cuWelfordCombine(wd, wdB);
203         }
204         __syncthreads();
205       }
206       if (threadIdx.x == 0 && threadIdx.y ==0) {
207         meansigmabuf[0] = wd.mean;
208         meansigmabuf[1] = wd.sigma2/float(N);
209       }
210       __syncthreads();
211       return WelfordDataLN{meansigmabuf[0], meansigmabuf[1],0.f};
212 
213     } else {
214       return WelfordDataLN{WARP_SHFL(wd.mean,0), WARP_SHFL(wd.sigma2,0)/float(N), 0.f};
215     }
216 }
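// Shared-memory layout used by compute_stats above, for reference: the first
// blockDim.y floats of `buf` hold interleaved (mean, sigma2) pairs written by
// the upper half of the warps, and the next blockDim.y / 2 floats hold their
// counts, so at most 3/2 * blockDim.y floats are touched. This matches the
// threads.y * 3/2 * sizeof(T_ACC) of dynamic shared memory requested by
// launch_vectorized_layer_norm_kernel below.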
217 
218 
219 template <typename T, typename T_ACC,
220 typename std::enable_if<!std::is_same<T, double>::value, int>::type = 0>
221 __device__ __inline__ void vectorized_layer_norm_kernel_impl(
222   const int N,
223   T_ACC eps,
224   const  T* __restrict__ X,
225   const  T* gamma,
226   const  T* beta,
227   T_ACC* mean,
228   T_ACC* rstd,
229   T* Y){
230     extern __shared__ float s_data[]; //if we made smem WelfordDataLN type, there would be bank conflicts,
231     //as one thread would have to write 3 consecutive floats
232     auto i1 = blockIdx.x;
233     const T * block_row = X + i1 * N;
234     WelfordDataLN wd = compute_stats(block_row, N, s_data);
235 
236     using vec_t = aligned_vector<T, vec_size>;
237     const vec_t * X_vec = reinterpret_cast<const vec_t*>(block_row);
238     const vec_t * gamma_vec = (gamma != nullptr) ? reinterpret_cast<const vec_t*>(gamma) : nullptr;
239     const vec_t * beta_vec = (beta != nullptr) ? reinterpret_cast<const vec_t*>(beta) : nullptr;
240     vec_t * Y_vec = reinterpret_cast<vec_t*>(Y + i1 * N);
241 
242     const int numx = blockDim.x * blockDim.y;
243     const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
244     const int n_vec_to_read = N/vec_size;
245 
246     T_ACC rstd_val = c10::cuda::compat::rsqrt(wd.sigma2 + eps);
247 
248     // No tail loop: N is guaranteed to be a multiple of vec_size
249     for (int i = thrx; i < n_vec_to_read; i += numx) {
250       vec_t data = X_vec[i];
251       vec_t out;
252 
253       // Computation is performed in T_ACC, X is cast to T_ACC and result is implicitly cast to T
254       if (gamma_vec != nullptr && beta_vec != nullptr) {
255         #pragma unroll
256         for (int ii=0; ii < vec_size; ii++){
257           out.val[ii] = static_cast<T_ACC>(gamma_vec[i].val[ii]) * (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean))
258             + static_cast<T_ACC>(beta_vec[i].val[ii]);
259         }
260       } else if (gamma_vec != nullptr) {
261         #pragma unroll
262         for (int ii=0; ii < vec_size; ii++){
263           out.val[ii] = static_cast<T_ACC>(gamma_vec[i].val[ii]) * (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean));
264         }
265       } else if (beta_vec != nullptr) {
266         #pragma unroll
267         for (int ii=0; ii < vec_size; ii++){
268           out.val[ii] = (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean)) + static_cast<T_ACC>(beta_vec[i].val[ii]);
269         }
270       } else {
271         #pragma unroll
272         for (int ii=0; ii < vec_size; ii++){
273           out.val[ii] = rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean);
274         }
275       }
276       Y_vec[i] = out;
277     }
278     if (thrx == 0) {
279       mean[i1] = wd.mean;
280       rstd[i1] = rstd_val;
281     }
282 }
283 
284 template <typename T, typename T_ACC,
285 typename std::enable_if<std::is_same<T, double>::value, int>::type = 0>
286 __device__ __inline__ void vectorized_layer_norm_kernel_impl(
287   const int /*N*/,
288   T_ACC /*eps*/,
289   const  T* __restrict__ /*X*/,
290   const  T* /*gamma*/,
291   const  T* /*beta*/,
292   T_ACC* /*mean*/,
293   T_ACC* /*rstd*/,
294   T* /*Y*/){
295     CUDA_KERNEL_ASSERT(false && "doesn't work with double");
296   }
297 
298 // to avoid Windows SFINAE errors
299 template <typename T, typename T_ACC>
300 __global__ void vectorized_layer_norm_kernel(
301   const int N,
302   T_ACC eps,
303   const  T* __restrict__ X,
304   const  T* gamma,
305   const  T* beta,
306   T_ACC* mean,
307   T_ACC* rstd,
308   T* Y){
309     vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y);
310   }
311 
312 
313 template<typename T, typename T_ACC>
314 __device__ __inline__ void compute_gI(
315   const T* __restrict__ dY,
316   const T* __restrict__ X,
317   const T_ACC* __restrict__ mean,
318   const T_ACC* __restrict__ rstd,
319   const T* __restrict__ gamma,
320   T* dX,
321   const int N,
322   T_ACC * buf){
323     const auto i1 = blockIdx.x;
324     const T_ACC mean_val = mean[i1];
325     const T_ACC rstd_val = rstd[i1];
326     T_ACC stats_x1{0}, stats_x2{0};
327     constexpr int unroll = 4;
328     auto l = unroll * threadIdx.x;
329     const T * X_i = X + i1 * N;
330     const T * dY_i = dY + i1 * N;
331     T * dX_i = dX + i1 * N;
332     //vectorized reads don't improve perf, so use regular unrolling
333 
334     for (; l+unroll - 1 < N; l += blockDim.x * unroll){
335       #pragma unroll
336       for (int k=0; k< unroll; k++){
337           const auto gamma_val = (gamma != nullptr) ? static_cast<T_ACC>(gamma[l+k]) : T_ACC(1);
338           const auto c_h = static_cast<T_ACC>(X_i[l+k]);
339           const auto c_loss = static_cast<T_ACC>(dY_i[l+k]);
340           stats_x1 += c_loss * gamma_val;
341           stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
342       }
343     }
344     for (;  l < N; l ++) {
345           const auto gamma_val = (gamma != nullptr) ? static_cast<T_ACC>(gamma[l]) : T_ACC(1);
346           const auto c_h = static_cast<T_ACC>(X_i[l]);
347           const auto c_loss = static_cast<T_ACC>(dY_i[l]);
348           stats_x1 += c_loss * gamma_val;
349           stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
350     }
351 
352     stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf);
353     stats_x2 = cuda_utils::BlockReduceSum(stats_x2, buf);
354     if (threadIdx.x == 0) {
355       buf[0] = stats_x1;
356       buf[1] = stats_x2;
357     }
358     __syncthreads();
359     stats_x1 = buf[0];
360     stats_x2 = buf[1];
361     T_ACC fH = N;
362     T_ACC term1 = (T_ACC(1) / fH) * rstd_val;
363 
364     for (int l = threadIdx.x; l < N; l += blockDim.x){
365         const auto x = X_i[l];
366         const auto dy = dY_i[l];
367         const auto gamma_val = (gamma != nullptr) ? static_cast<T_ACC>(gamma[l]) : T_ACC(1);
368 
369         T_ACC f_grad_input = fH * gamma_val * dy;
370         f_grad_input -= (x - mean_val) * rstd_val * stats_x2;
371         f_grad_input -= stats_x1;
372         f_grad_input *= term1;
373         dX_i[l] = f_grad_input;
374     }
375   }
376 
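// For reference, compute_gI above evaluates the usual LayerNorm input gradient
// for one row (N = row length, rstd = 1/sqrt(var + eps)):
//   s1   = sum_j gamma_j * dY_j
//   s2   = sum_j gamma_j * dY_j * (X_j - mean) * rstd
//   dX_j = (rstd / N) * (N * gamma_j * dY_j - s1 - (X_j - mean) * rstd * s2)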
377 
378 template<typename T, typename T_ACC>
379 __global__ void layer_norm_grad_input_kernel(
380   const T* __restrict__ dY,
381   const T* __restrict__ X,
382   const T_ACC* __restrict__ mean,
383   const T_ACC* __restrict__ rstd,
384   const T* __restrict__ gamma,
385   T*  dX,
386   const int N){
387     alignas(sizeof(double)) extern __shared__ char s_data1[];
388     T_ACC * buf = reinterpret_cast<T_ACC*>(&s_data1);
389 
390     compute_gI(dY, X, mean, rstd, gamma, dX, N, buf);
391   }
392 
393 
394 // This implementation gets called when input buffers (dY, X, gamma and dX) are aligned
395 // to vec_size * sizeof(T). Compared to the unvectorized implementation, it is about 10%
396 // faster measured at the PT operator level, with some cases (where N >> M) seeing a 2X speedup.
397 // There are no noticeable regressions on the rest of the sizes.
398 
399 template<typename T, typename T_ACC>
400 __global__ void layer_norm_grad_input_kernel_vectorized(
401   const T* __restrict__ dY,
402   const T* __restrict__ X,
403   const T_ACC* __restrict__ mean,
404   const T_ACC* __restrict__ rstd,
405   const T* __restrict__ gamma,
406   T* dX,
407   const int N) {
408   alignas(sizeof(double)) extern __shared__ char shared_data[];
409   T_ACC* reduce_buf = reinterpret_cast<T_ACC*>(&shared_data);
410 
411   const auto bIdx = blockIdx.x;
412   const T_ACC mean_val = mean[bIdx];
413   const T_ACC rstd_val = rstd[bIdx];
414   const T* X_i = X + bIdx * N;
415   const T* dY_i = dY + bIdx * N;
416   T* dX_i = dX + bIdx * N;
417 
418   using vec_t = aligned_vector<T, vec_size>;
419   const vec_t* const X_i_vec_ptr = reinterpret_cast<const vec_t*>(X_i);
420   const vec_t* const dY_i_vec_ptr = reinterpret_cast<const vec_t*>(dY_i);
421   const vec_t* const gamma_vec_ptr = (gamma != nullptr) ? reinterpret_cast<const vec_t*>(gamma) : nullptr;
422   vec_t* const dX_i_vec = reinterpret_cast<vec_t*>(dX_i);
423 
424   vec_t X_i_vec_reg, dY_i_vec_reg, gamma_vec_reg, dX_i_vec_reg;
425   for (int k = 0; k < vec_size; ++k) {
426     gamma_vec_reg.val[k] = T(1);
427   }
428 
429   T_ACC stats_x1{0}, stats_x2{0};
430   unsigned int l = threadIdx.x * vec_size;
431   for (; l + vec_size - 1 < N; l += blockDim.x * vec_size) {
432     unsigned int vec_idx = l / vec_size;
433     if (gamma != nullptr) {
434       gamma_vec_reg = gamma_vec_ptr[vec_idx];
435     }
436 
437     X_i_vec_reg = X_i_vec_ptr[vec_idx];
438     dY_i_vec_reg = dY_i_vec_ptr[vec_idx];
439 
440     for (int k = 0; k < vec_size; ++k) {
441       const auto gamma_val = static_cast<T_ACC>(gamma_vec_reg.val[k]);
442       const auto c_h = static_cast<T_ACC>(X_i_vec_reg.val[k]);
443       const auto c_loss = static_cast<T_ACC>(dY_i_vec_reg.val[k]);
444       stats_x1 += c_loss * gamma_val;
445       stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
446     }
447   }
448 
449   // Tail Loop
450   for (; l < N; l++) {
451     const auto gamma_val = (gamma != nullptr) ? static_cast<T_ACC>(gamma[l]) : T_ACC(1);
452     const auto c_h = static_cast<T_ACC>(X_i[l]);
453     const auto c_loss = static_cast<T_ACC>(dY_i[l]);
454     stats_x1 += c_loss * gamma_val;
455     stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
456   }
457 
458   // Reduction in Shared Memory
459   stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf);
460   stats_x2 = cuda_utils::BlockReduceSum(stats_x2, reduce_buf);
461   if (threadIdx.x == 0) {
462     reduce_buf[0] = stats_x1;
463     reduce_buf[1] = stats_x2;
464   }
465   __syncthreads();
466   stats_x1 = reduce_buf[0];
467   stats_x2 = reduce_buf[1];
468 
469   T_ACC fH = N;
470   T_ACC term1 = (T_ACC(1) / fH) * rstd_val;
471 
472   l = threadIdx.x * vec_size;
473   for (; l + vec_size - 1 < N; l += blockDim.x * vec_size) {
474     unsigned int vec_idx = l / vec_size;
475     if (gamma != nullptr) {
476       gamma_vec_reg = gamma_vec_ptr[vec_idx];
477     }
478 
479     X_i_vec_reg = X_i_vec_ptr[vec_idx];
480     dY_i_vec_reg = dY_i_vec_ptr[vec_idx];
481 
482     for (int k = 0; k < vec_size; ++k) {
483       const auto gamma_val = static_cast<T_ACC>(gamma_vec_reg.val[k]);
484       const auto x = static_cast<T_ACC>(X_i_vec_reg.val[k]);
485       const auto dy = static_cast<T_ACC>(dY_i_vec_reg.val[k]);
486 
487       T_ACC f_grad_input = fH * gamma_val * dy;
488       f_grad_input -= (x - mean_val) * rstd_val * stats_x2;
489       f_grad_input -= stats_x1;
490       f_grad_input *= term1;
491       dX_i_vec_reg.val[k] = f_grad_input;
492     }
493 
494     dX_i_vec[vec_idx] = dX_i_vec_reg;
495   }
496 
497   // Tail Loop
498   for (; l < N; l += blockDim.x) {
499     const auto x = X_i[l];
500     const auto dy = dY_i[l];
501     const auto gamma_val = (gamma != nullptr) ? static_cast<T_ACC>(gamma[l]) : T_ACC(1);
502 
503     T_ACC f_grad_input = fH * gamma_val * dy;
504     f_grad_input -= (x - mean_val) * rstd_val * stats_x2;
505     f_grad_input -= stats_x1;
506     f_grad_input *= term1;
507     dX_i[l] = f_grad_input;
508   }
509 }
510 
511 
512 template <typename T, typename T_ACC>
513 __global__ void GammaBetaBackwardSimpleCUDAKernel(
514     int64_t M,
515     int64_t N,
516     const T* dY,
517     const T* X,
518     const T_ACC* mean,
519     const T_ACC* rstd,
520     T* dg,
521     T* db) {
522   const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
523   if (j < N) {
524     T_ACC sum1 = 0;
525     T_ACC sum2 = 0;
526     for (int64_t i = 0; i < M; ++i) {
527       const int64_t index = i * N + j;
528       sum1 += dg == nullptr ? T_ACC(0)
529                             : static_cast<T_ACC>(dY[index]) *
530               (static_cast<T_ACC>(X[index]) - static_cast<T_ACC>(mean[i])) *
531               static_cast<T_ACC>(rstd[i]);
532       sum2 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index]);
533     }
534     if (dg != nullptr) {
535       dg[j] = sum1;
536     }
537     if (db != nullptr) {
538       db[j] = sum2;
539     }
540   }
541 }
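// Illustrative host-side sketch of the column-wise reductions that the
// gamma/beta backward kernels (the simple one above and the tiled ones below)
// compute. The function name is an assumption; requires <vector> if compiled.
#if 0
void gamma_beta_backward_reference(int64_t M, int64_t N,
                                   const std::vector<float>& dY,   // M x N, row-major
                                   const std::vector<float>& X,    // M x N, row-major
                                   const std::vector<float>& mean, // M
                                   const std::vector<float>& rstd, // M
                                   std::vector<float>& dg,         // N
                                   std::vector<float>& db) {       // N
  for (int64_t j = 0; j < N; ++j) {
    float sum1 = 0.f, sum2 = 0.f;
    for (int64_t i = 0; i < M; ++i) {
      const float dy = dY[i * N + j];
      sum1 += dy * (X[i * N + j] - mean[i]) * rstd[i]; // dgamma column sum
      sum2 += dy;                                      // dbeta column sum
    }
    dg[j] = sum1;
    db[j] = sum2;
  }
}
#endif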
542 
543 // This implementation gets called if M and N are divisible by 32 (the warp
544 // size). This case should be the most common, and it lets us make better use
545 // of warp-level intrinsics to improve performance.
546 
547 template <typename T, typename T_ACC>
548 __global__ void GammaBetaBackwardCUDAKernel_32x32(
549     int64_t M,
550     int64_t N,
551     const T* dY,
552     const T* X,
553     const T_ACC* mean,
554     const T_ACC* rstd,
555     T* dg,
556     T* db) {
557   alignas(sizeof(double)) extern __shared__ char s_data1[];
558   T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
559   T_ACC* s_dg;
560   T_ACC* s_db;
561 
562   T_ACC dg_sum = 0;
563   T_ACC db_sum = 0;
564 
565   const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
566 
567   if (j < N) {
568     constexpr int unroll_factor = 8;
569     int laneId = threadIdx.x & (C10_WARP_SIZE - 1);
570 
571     T_ACC mean_reg, mean_reg_tmp;
572     T_ACC rstd_reg, rstd_reg_tmp;
573     T dY_reg;
574     T X_reg;
575 
576     // Main loop
577     int bcounter;
578     for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor);
579          bcounter++) {
580       int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
581 
582       if (laneId < unroll_factor) {
583         mean_reg_tmp = mean[offset + laneId];
584         rstd_reg_tmp = rstd[offset + laneId];
585       }
586       WARP_SYNC();
587 
588       #pragma unroll
589       for (int ii = 0; ii < unroll_factor; ++ii) {
590         dY_reg = dY[(offset + ii) * N + j];
591         X_reg = X[(offset + ii) * N + j];
592         mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize);
593         rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize);
594         dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
595         db_sum += dY_reg;
596       }
597     }
598 
599     // Remainder loop
600     int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
601     for (int ii = 0; ii < unroll_factor; ii++) {
602       if ((offset + ii) < M) {
603         mean_reg = mean[offset + ii];
604         rstd_reg = rstd[offset + ii];
605         dY_reg = dY[(offset + ii) * N + j];
606         X_reg = X[(offset + ii) * N + j];
607         dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
608         db_sum += dY_reg;
609       }
610     }
611 
612     // This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and
613     // gets called when M and N are divisible by 32. We can use warp shuffles
614     // for the final reduction step, which removes 4 shmem loads and
615     // stores with their corresponding __syncthreads()
616 
617     // This greatly reduces bank conflicts at the expense of a little
618     // extra shared memory. It does not impact occupancy
619     int padded_bx = (1 + blockDim.x);
620 
621     s_dg = s_data_typed;
622     s_db = s_data_typed + (padded_bx * blockDim.y);
623     s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum;
624     s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum;
625     __syncthreads();
626 
627     // Load transposed so that a warp holds an entire column
628     T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y];
629     T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y];
630     for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) {
631       reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
632       reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
633     }
634 
635     if (threadIdx.x == 0) {
636       const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
637       if (dg) {
638         dg[j] = reg_dg;
639       }
640       if (db) {
641         db[j] = reg_db;
642       }
643     }
644   }
645 }
646 
647 template <typename T, typename T_ACC>
648 __global__ void GammaBetaBackwardCUDAKernel(
649     int64_t M,
650     int64_t N,
651     const T* dY,
652     const T* X,
653     const T_ACC* mean,
654     const T_ACC* rstd,
655     T* dg,
656     T* db) {
657   alignas(sizeof(double)) extern __shared__ char s_data1[];
658   T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
659   T_ACC* s_dg;
660   T_ACC* s_db;
661 
662   const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
663 
664   T_ACC dg_sum = 0;
665   T_ACC db_sum = 0;
666 
667   if (j < N) {
668     constexpr int unroll_factor = 8;
669 
670     T_ACC mean_reg;
671     T_ACC rstd_reg;
672     T dY_reg;
673     T X_reg;
674 
675     // Main Loop
676     int bcounter;
677     for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){
678       int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
679 
680       #pragma unroll
681       for (int ii = 0; ii < unroll_factor; ++ii) {
682         dY_reg = dY[(offset + ii) * N + j];
683         X_reg = X[(offset + ii) * N + j];
684         mean_reg = mean[offset + ii];
685         rstd_reg = rstd[offset + ii];
686         dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
687         db_sum += dY_reg;
688       }
689     }
690 
691     // Remainder loop
692     int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
693     for (int ii = 0; ii < unroll_factor; ii++ ){
694       if ((offset + ii) < M) {
695         dY_reg = dY[(offset + ii) * N + j ];
696         X_reg = X[(offset + ii) * N + j];
697         mean_reg = mean[offset + ii];
698         rstd_reg = rstd[offset + ii];
699         dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
700         db_sum += dY_reg;
701       }
702     }
703 
704     // Do the final reduction in shared memory
705     s_dg = s_data_typed;
706     s_db = s_data_typed + blockDim.x * blockDim.y;
707     s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
708     s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
709     __syncthreads();
710 
711     for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
712       if (threadIdx.y < offset) {
713         s_dg[threadIdx.y * blockDim.x + threadIdx.x] +=
714             s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
715         s_db[threadIdx.y * blockDim.x + threadIdx.x] +=
716             s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
717         }
718       __syncthreads();
719     }
720 
721     if (threadIdx.y == 0) {
722       if (dg) {
723         dg[j] = s_dg[threadIdx.x];
724       }
725       if (db) {
726         db[j] = s_db[threadIdx.x];
727       }
728     }
729   }
730 }
731 
732 template <typename T, typename T_ACC>
733 void launch_vectorized_layer_norm_kernel(
734   int N,
735   int64_t M,
736   T_ACC eps,
737   const T* X_data,
738   const T* gamma_data,
739   const T* beta_data,
740   T* Y_data,
741   T_ACC* mean_data,
742   T_ACC* rstd_data
743 ) {
744     //constexpr int alignment = 16; // currently unused, kept to make sure float and half results are bitwise accurate
745     auto stream = at::cuda::getCurrentCUDAStream().stream();
746     const int warp_size = at::cuda::warp_size();
747     const dim3 threads(warp_size, num_threads() / warp_size, 1);
748     const dim3 blocks(M);
749     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(threads.y % 2 == 0 || threads.y == 1);
750     int nshared = threads.y > 1 ? threads.y * 3/2 *sizeof(T_ACC) : 0;
751     vectorized_layer_norm_kernel<<<blocks, threads, nshared, stream>>>(N, eps, X_data,
752     gamma_data, beta_data, mean_data, rstd_data, Y_data);
753     C10_CUDA_KERNEL_LAUNCH_CHECK();
754 }
755 
756 template <typename T, typename T_ACC>
757 void LayerNormKernelImplInternal(
758     const Tensor& X,
759     const Tensor& gamma,
760     const Tensor& beta,
761     int64_t M,
762     int64_t N,
763     T_ACC eps,
764     Tensor* Y,
765     Tensor* mean,
766     Tensor* rstd) {
767   // assumes input, gamma and beta are of proper shape; this was checked in _check_layer_norm_inputs
768   // assumes all tensors are contiguous
769   TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \
770   file a support request to support bigger batches");
771   const T* X_data = X.const_data_ptr<T>();
772   const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
773   const T* beta_data = beta.defined() ? beta.const_data_ptr<T>() : nullptr;
774   T* Y_data = Y->data_ptr<T>();
775   T_ACC* mean_data = mean->data_ptr<T_ACC>();
776   T_ACC* rstd_data = rstd->data_ptr<T_ACC>();
777 
778   // check if we can take the fast path - all tensors are properly aligned, N is less than 2^24 (so a float count is exact),
779   // and N is a multiple of vec_size (so that all rows are aligned if the tensor is aligned)
780   constexpr int num_vec_elems = vec_size;
781   constexpr int alignment = num_vec_elems * sizeof(T);
782   bool can_vec_X = can_vectorize(X_data, alignment);
783   bool can_vec_Y = can_vectorize(Y_data, alignment);
784   bool can_vec_gamma = gamma.defined() ? can_vectorize(gamma_data, alignment) : true;
785   bool can_vec_beta = beta.defined() ? can_vectorize(beta_data, alignment) : true;
786 
787   if ((std::is_same<T, float>::value || std::is_same<T, at::Half>::value || std::is_same<T, at::BFloat16>::value) &&
788   N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) && N % num_vec_elems == 0 &&
789   can_vec_X && can_vec_Y && can_vec_gamma && can_vec_beta) {
790     launch_vectorized_layer_norm_kernel(static_cast<int>(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data);
791   } else {
792   cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
793   RowwiseMomentsCUDAKernel<T, T_ACC>
794       <<<M, cuda_utils::kCUDABlockReduceNumThreads, 0, cuda_stream>>>(
795           N, eps, X_data, mean_data, rstd_data);
796   C10_CUDA_KERNEL_LAUNCH_CHECK();
797   LayerNormForwardCUDAKernel<T, T_ACC><<<M, kCUDANumThreads, 0, cuda_stream>>>(
798       N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data);
799   C10_CUDA_KERNEL_LAUNCH_CHECK();
800   }
801 }
802 
803 void LayerNormKernelImpl(
804     const Tensor& X,
805     const Tensor& gamma,
806     const Tensor& beta,
807     int64_t M,
808     int64_t N,
809     double eps,
810     Tensor* Y,
811     Tensor* mean,
812     Tensor* rstd) {
813   AT_DISPATCH_FLOATING_TYPES_AND2(
814       at::ScalarType::Half,
815       at::ScalarType::BFloat16,
816       X.scalar_type(),
817       "LayerNormKernelImpl",
818       [&]() {
819         using acc_t = acc_type<scalar_t, true>;
820         LayerNormKernelImplInternal<scalar_t, acc_t>(
821             X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
822       });
823 }
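// Minimal usage sketch (an assumption, relying only on the public ATen API):
// at::native_layer_norm on contiguous CUDA tensors should reach
// LayerNormKernelImpl above and return (Y, mean, rstd).
#if 0
#include <ATen/ATen.h>

void layer_norm_forward_demo() {
  at::Tensor x     = at::randn({8, 1024}, at::kCUDA);
  at::Tensor gamma = at::ones({1024}, at::kCUDA);
  at::Tensor beta  = at::zeros({1024}, at::kCUDA);
  auto [y, mean, rstd] =
      at::native_layer_norm(x, /*normalized_shape=*/{1024}, gamma, beta, /*eps=*/1e-5);
}
#endif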
824 
825 template<typename T, typename T_ACC> __device__
826 void cuLoadWriteStridedInputs(
827     const int i1_block,
828     const int thr_load_row_off,
829     const int thr_load_col_off,
830     const int i2_off,
831     const int row_stride,
832     T_ACC* warp_buf1,
833     T_ACC* warp_buf2,
834     const T* input,
835     const T* dout,
836     const int i1_end,
837     const int64_t N,
838     const T_ACC* __restrict__ mean,
839     const T_ACC* __restrict__ rstd)
840 {
841   int i1 = i1_block+thr_load_row_off;
842   if (i1 < i1_end) {
843     T curr_mean = mean[i1];
844     T curr_rstd = rstd[i1];
845     for (int k = 0;  k < blockDim.y;  ++k) {
846       int i2 = i2_off + k;
847       int load_idx = i1*N+i2;
848       int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
849       if (i2<N) {
850         T curr_input = static_cast<T>(input[load_idx]);
851         T curr_dout = static_cast<T>(dout[load_idx]);
852         warp_buf1[write_idx] = curr_dout;
853         warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_rstd;
854       } else {
855         warp_buf1[write_idx] = T(0);
856         warp_buf2[write_idx] = T(0);
857       }
858     }
859   } else {
860     for (int k = 0;  k < blockDim.y;  ++k) {
861       int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
862       warp_buf1[write_idx] = T(0);
863       warp_buf2[write_idx] = T(0);
864     }
865   }
866 }
867 
868 template<typename T, typename T_ACC> __device__
869 void cuLoadAddStridedInputs(
870     const int i1_block,
871     const int thr_load_row_off,
872     const int thr_load_col_off,
873     const int i2_off,
874     const int row_stride,
875     T_ACC* warp_buf1,
876     T_ACC* warp_buf2,
877     const T* input,
878     const T* dout,
879     const int i1_end,
880     const int64_t N,
881     const T_ACC* __restrict__ mean,
882     const T_ACC* __restrict__ rstd)
883 {
884   int i1 = i1_block+thr_load_row_off;
885   if (i1 < i1_end) {
886     T_ACC curr_mean = mean[i1];
887     T_ACC curr_rstd = rstd[i1];
888     for (int k = 0;  k < blockDim.y;  ++k) {
889       int i2 = i2_off + k;
890       int load_idx = i1*N+i2;
891       int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
892       if (i2<N) {
893         T_ACC curr_input = static_cast<T_ACC>(input[load_idx]);
894         T_ACC curr_dout = static_cast<T_ACC>(dout[load_idx]);
895         warp_buf1[write_idx] += curr_dout;
896         warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_rstd;
897       }
898     }
899   }
900 }
901 
902 template<typename T, typename T_ACC> __global__
903 void cuComputePartGradGammaBeta(
904     const T* __restrict__ dout,
905     const T* __restrict__ input,
906     const int64_t M,
907     const int64_t N,
908     const T_ACC* __restrict__ mean,
909     const T_ACC* __restrict__ rstd,
910     T_ACC* part_grad_gamma,
911     T_ACC* part_grad_beta)
912 {
913     const int numsegs_M = (M+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
914     const int segs_per_block = (numsegs_M + gridDim.y - 1) / gridDim.y;
915     const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
916     const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
917     const int i1_end = i1_beg_plus_one < M ? i1_beg_plus_one : M;
918     const int row_stride = blockDim.x+1;
919     const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
920     const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
921     const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
922     alignas(sizeof(double)) extern __shared__ char shared[];
923     T_ACC * buf = reinterpret_cast<T_ACC*>(&shared); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
924     T_ACC* warp_buf1 = (T_ACC*)buf;
925     T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
926     // compute partial sums from strided inputs
927     // do this to increase number of loads in flight
928     cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd);
929     for (int i1_block = i1_beg+blockDim.y*blockDim.y;  i1_block < i1_end;  i1_block+=blockDim.y*blockDim.y) {
930       cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd);
931     }
932     __syncthreads();
933     // inter-warp reductions
934     // sum within each warp
935     T_ACC acc1 = T_ACC(0);
936     T_ACC acc2 = T_ACC(0);
937     for (int k = 0;  k < blockDim.y;  ++k) {
938       int row1 = threadIdx.y + k*blockDim.y;
939       int idx1 = row1*row_stride + threadIdx.x;
940       acc1 += warp_buf1[idx1];
941       acc2 += warp_buf2[idx1];
942     }
943     warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
944     warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
945     __syncthreads();
946     // sum all warps
947     for (int offset = blockDim.y/2;  offset > 1;  offset /= 2) {
948       if (threadIdx.y < offset) {
949         int row1 = threadIdx.y;
950         int row2 = threadIdx.y + offset;
951         int idx1 = row1*row_stride + threadIdx.x;
952         int idx2 = row2*row_stride + threadIdx.x;
953         warp_buf1[idx1] += warp_buf1[idx2];
954         warp_buf2[idx1] += warp_buf2[idx2];
955       }
956       __syncthreads();
957     }
958     int i2 = blockIdx.x * blockDim.x + threadIdx.x;
959     if (threadIdx.y == 0 && i2 < N) {
960       int row1 = threadIdx.y;
961       int row2 = threadIdx.y + 1;
962       int idx1 = row1*row_stride + threadIdx.x;
963       int idx2 = row2*row_stride + threadIdx.x;
964       part_grad_beta[blockIdx.y*N+i2] = warp_buf1[idx1] + warp_buf1[idx2];
965       part_grad_gamma[blockIdx.y*N+i2] = warp_buf2[idx1] + warp_buf2[idx2];
966     }
967 }
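// cuComputePartGradGammaBeta above writes one partially reduced row per
// blockIdx.y into part_grad_gamma / part_grad_beta (shape [gridDim.y, N]);
// cuComputeGradGammaBeta below then reduces those gridDim.y partial rows into
// the final grad_gamma / grad_beta of length N.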
968 
969 template<typename T, typename T_ACC> __global__
970 void cuComputeGradGammaBeta(
971     const T_ACC* part_grad_gamma,
972     const T_ACC* part_grad_beta,
973     const int part_size,
974     const int64_t M,
975     const int64_t N,
976     T* grad_gamma,
977     T* grad_beta)
978 {
979     // sum partial gradients for gamma and beta
980     alignas(sizeof(double)) extern __shared__ char shared[];
981     T_ACC * buf = reinterpret_cast<T_ACC*>(&shared);
982     int i2 = blockIdx.x * blockDim.x + threadIdx.x;
983 
984     // each warp does sequential reductions until reduced part_size is num_warps
985     int num_warp_reductions = part_size / blockDim.y;
986     T_ACC sum_gamma = T_ACC(0);
987     T_ACC sum_beta = T_ACC(0);
988     const T_ACC* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * N + i2;
989     const T_ACC* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * N + i2;
990 
991     if (i2 < N) {
992         for (int warp_offset = 0;  warp_offset < num_warp_reductions;  ++warp_offset) {
993           sum_gamma += part_grad_gamma_ptr[warp_offset*N];
994           sum_beta += part_grad_beta_ptr[warp_offset*N];
995         }
996     }
997 
998     // inter-warp reductions
999     const int nbsize3 = blockDim.x * blockDim.y / 2;
1000     for (int offset = blockDim.y/2;  offset >= 1;  offset /= 2) {
1001       // top half write to shared memory
1002       if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
1003         const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
1004         buf[write_idx] = sum_gamma;
1005         buf[write_idx+nbsize3] = sum_beta;
1006       }
1007       __syncthreads();
1008       // bottom half sums
1009       if (threadIdx.y < offset) {
1010         const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
1011         sum_gamma += buf[read_idx];
1012         sum_beta += buf[read_idx+nbsize3];
1013       }
1014       __syncthreads();
1015     }
1016 
1017     // write out fully summed gradients
1018     if (threadIdx.y == 0 && i2 < N) {
1019       if (grad_gamma) {
1020           grad_gamma[i2] = sum_gamma;
1021       }
1022       if (grad_beta) {
1023           grad_beta[i2] = sum_beta;
1024       }
1025     }
1026 }
1027 
1028 template<typename T, typename T_ACC> __global__
1029 void cuComputeGradInput(
1030     const T* __restrict__ dout,
1031     const T* __restrict__ input,
1032     const int64_t M,
1033     const int64_t N,
1034     const T_ACC* __restrict__ mean,
1035     const T_ACC* __restrict__ rstd,
1036     const T* gamma,
1037     T* grad_input)
1038 {
1039   for (int i1=blockIdx.y; i1 < M; i1 += gridDim.y) {
1040     T_ACC sum_loss1 = T_ACC(0);
1041     T_ACC sum_loss2 = T_ACC(0);
1042     T_ACC c_mean = mean[i1];
1043     const T_ACC c_rstd = rstd[i1];
1044     const T* k_input = input + i1*N;
1045     const T* k_dout = dout + i1*N;
1046     const int numx = blockDim.x * blockDim.y;
1047     const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
1048     if (gamma != NULL) {
1049       // Optimization for ROCm MI100
1050       for( int l = 0; l < N ; l += numx) {
1051         int idx = l + thrx;
1052         const T_ACC gamma_idx = static_cast<T_ACC>((idx<N) ? gamma[idx] : T(0));
1053         const T_ACC c_h = static_cast<T_ACC>((idx<N) ? k_input[idx] : T(0));
1054         const T_ACC c_loss = static_cast<T_ACC>((idx<N) ? k_dout[idx] : T(0));
1055         sum_loss1 += c_loss * gamma_idx;
1056         sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_rstd;
1057       }
1058     } else {
1059       for( int l = 0; l < N ; l += numx) {
1060         int idx = l + thrx;
1061         const T_ACC c_h = static_cast<T_ACC>((idx<N) ? k_input[idx] : T(0));
1062         const T_ACC c_loss = static_cast<T_ACC>((idx<N) ? k_dout[idx] : T(0));
1063         sum_loss1 += c_loss;
1064         sum_loss2 += c_loss * (c_h - c_mean) * c_rstd;
1065       }
1066     }
1067     // intra-warp reductions
1068     for (int mask = blockDim.x/2;  mask > 0;  mask /= 2) {
1069       sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
1070       sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
1071     }
1072     // inter-warp reductions
1073     if (blockDim.y > 1) {
1074       alignas(sizeof(double)) extern __shared__ char shared[];
1075       T_ACC * buf = reinterpret_cast<T_ACC*>(&shared);
1076       for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
1077         // upper half of warps write to shared
1078         if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
1079           const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
1080           buf[2*wrt_i] = sum_loss1;
1081           buf[2*wrt_i+1] = sum_loss2;
1082         }
1083         __syncthreads();
1084         // lower half merges
1085         if (threadIdx.y < offset) {
1086           const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
1087           sum_loss1 += buf[2*read_i];
1088           sum_loss2 += buf[2*read_i+1];
1089         }
1090         __syncthreads();
1091       }
1092       if (threadIdx.y == 0) {
1093         buf[2*threadIdx.x] = sum_loss1;
1094         buf[2*threadIdx.x+1] = sum_loss2;
1095       }
1096       __syncthreads();
1097       if (threadIdx.y !=0) {
1098         sum_loss1 = buf[2*threadIdx.x];
1099         sum_loss2 = buf[2*threadIdx.x+1];
1100       }
1101     }
1102     // all threads now have the two sums over l
1103     T_ACC fH = (T_ACC)N;
1104     T_ACC term1 = (T_ACC(1) / fH) * c_rstd;
1105     T* k_grad_input = grad_input + i1*N;
1106     if (gamma != NULL) {
1107       for (int l = thrx;  l < N;  l+=numx) {
1108         const T_ACC c_h = static_cast<T_ACC>(k_input[l]);
1109         const T_ACC c_loss = static_cast<T_ACC>(k_dout[l]);
1110         T_ACC f_grad_input = fH * c_loss * gamma[l];
1111         f_grad_input -= sum_loss1;
1112         f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2;
1113         f_grad_input *= term1;
1114         k_grad_input[l] = static_cast<T>(f_grad_input);
1115       }
1116     } else {
1117       for (int l = thrx;  l < N;  l+=numx) {
1118         const T_ACC c_h = static_cast<T_ACC>(k_input[l]);
1119         const T_ACC c_loss = static_cast<T_ACC>(k_dout[l]);
1120         T_ACC f_grad_input = fH * c_loss;
1121         f_grad_input -= sum_loss1;
1122         f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2;
1123         f_grad_input *= term1;
1124         k_grad_input[l] = static_cast<T>(f_grad_input);
1125       }
1126     }
1127     // prevent race where buf is written again before reads are done
1128     __syncthreads();
1129   }
1130 }
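// cuComputeGradInput above computes the same per-row input gradient as
// compute_gI / layer_norm_grad_input_kernel, but lets one block process
// several rows (i1 strided by gridDim.y); it is only launched on the ROCm
// path for very large M in LayerNormBackwardKernelImplInternal below.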
1131 
1132 template <typename T>
1133 void LayerNormBackwardKernelImplInternal(
1134     const Tensor& dY,
1135     const Tensor& X,
1136     const Tensor& mean,
1137     const Tensor& rstd,
1138     const Tensor& gamma,
1139     int64_t M,
1140     int64_t N,
1141     Tensor* dX,
1142     Tensor* dgamma,
1143     Tensor* dbeta) {
1144   using T_ACC = acc_type<T, true>;
1145   TORCH_CHECK(dY.numel() == M * N);
1146   TORCH_CHECK(mean.numel() == M);
1147   TORCH_CHECK(rstd.numel() == M);
1148   TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \
1149   file a support request to support bigger batches");
1150   TORCH_CHECK(N <= std::numeric_limits<int>::max(), "Normalized shape should have less than INT_MAX elements, \
1151   file a support request to support bigger normalized shapes");
1152   const T* dY_data = dY.template const_data_ptr<T>();
1153   const T* X_data = X.template const_data_ptr<T>();
1154   const T_ACC* mean_data = mean.template const_data_ptr<T_ACC>();
1155   const T_ACC* rstd_data = rstd.template const_data_ptr<T_ACC>();
1156   const T* gamma_data =
1157       gamma.defined() ? gamma.template const_data_ptr<T>() : nullptr;
1158   T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
1159   cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
1160   const int warp_size = at::cuda::warp_size();
1161   if (dX_data != nullptr) {
1162 #ifdef USE_ROCM
1163     if (M >= 32768) {
1164       const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
1165       const dim3 blocks1(1, std::min((uint64_t)M, maxGridY), 1);
1166       dim3 threads1(warp_size, 4, 1);
1167       threads1.y = 2; // Optimization for ROCm
1168       int nshared =
1169               threads1.y > 1 ?
1170               threads1.y*threads1.x*sizeof(T_ACC) :
1171               0;
1172       cuComputeGradInput<<<blocks1, threads1, nshared, cuda_stream>>>(
1173               dY_data,
1174               X_data,
1175               M, N,
1176               mean_data,
1177               rstd_data,
1178               gamma_data,
1179               dX_data);
1180         C10_CUDA_KERNEL_LAUNCH_CHECK();
1181     } else {
1182       const dim3 blocks(M);
1183       int nshared = (num_threads()/warp_size) * sizeof(T_ACC);
1184       layer_norm_grad_input_kernel<<<blocks, num_threads(), nshared, cuda_stream>>>(dY_data,
1185       X_data, mean_data, rstd_data, gamma_data, dX_data, N);
1186       C10_CUDA_KERNEL_LAUNCH_CHECK();
1187     }
1188 #else
1189     const dim3 blocks(M);
1190     int nshared = (num_threads() / warp_size) * sizeof(T_ACC);
1191 
1192     bool bVectorSizeMultiple = (N % vec_size == 0);
1193     bool bTargetDataTypes = (std::is_same<T, float>::value || std::is_same<T, at::Half>::value ||
1194       std::is_same<T, at::BFloat16>::value);
1195     const unsigned int alignment = sizeof(T) * vec_size;
1196     bool bAlignedBuffers = can_vectorize(dY_data, alignment) && can_vectorize(X_data, alignment) &&
1197       can_vectorize(gamma_data, alignment) && can_vectorize(dX_data, alignment);
1198 
1199     if (bAlignedBuffers && bTargetDataTypes && bVectorSizeMultiple) {
1200       layer_norm_grad_input_kernel_vectorized<<<blocks, num_threads(), nshared, cuda_stream>>>(dY_data,
1201           X_data, mean_data, rstd_data, gamma_data, dX_data, N);
1202       C10_CUDA_KERNEL_LAUNCH_CHECK();
1203     } else {
1204       layer_norm_grad_input_kernel<<<blocks, num_threads(), nshared, cuda_stream>>>(dY_data,
1205           X_data, mean_data, rstd_data, gamma_data, dX_data, N);
1206       C10_CUDA_KERNEL_LAUNCH_CHECK();
1207     }
1208 #endif
1209   }
1210 
1211   if (dgamma->defined() || dbeta->defined()) {
1212     T* dgamma_data =
1213         dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
1214     T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
1215 
1216     if (M < 128) {
1217       // For small batch size, do colwise reduce directly.
1218       const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads;
1219       GammaBetaBackwardSimpleCUDAKernel<T, T_ACC>
1220           <<<B, kCUDANumThreads, 0, cuda_stream>>>(
1221               M,
1222               N,
1223               dY_data,
1224               X_data,
1225               mean_data,
1226               rstd_data,
1227               dgamma_data,
1228               dbeta_data);
1229       C10_CUDA_KERNEL_LAUNCH_CHECK();
1230     } else {
1231 #if defined(USE_ROCM)
1232       // Two-pass reduction: compute partial column-wise sums per block row, then reduce them with cuComputeGradGammaBeta below.
1233       const int part_size = warp_size;
1234       const dim3 threads2(warp_size, 4, 1);
1235       const dim3 blocks2((N + threads2.x - 1) / threads2.x, part_size, 1);
1236       const int nshared2_a = 2 * sizeof(T_ACC) * threads2.y * threads2.y * (threads2.x + 1);
1237       const int nshared2_b = threads2.x * threads2.y * sizeof(T_ACC);
1238       const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
1239 
1240       const auto part_grad_dtype = at::toAccumulateType(X.scalar_type(), true);
1241       Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype));
1242       Tensor part_grad_beta = at::native::empty_like(part_grad_gamma);
1243 
1244       cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, cuda_stream>>>(
1245                       dY_data,
1246                       X_data,
1247                       M,N,
1248                       mean_data,
1249                       rstd_data,
1250                       part_grad_gamma.template data_ptr<T_ACC>(),
1251                       part_grad_beta.template data_ptr<T_ACC>());
1252       C10_CUDA_KERNEL_LAUNCH_CHECK();
1253 
1254       const dim3 threads3(warp_size, 8, 1); // Optimization for ROCm
1255       const dim3 blocks3((N + threads3.x - 1) / threads3.x, 1, 1);
1256       const int nshared3 = threads3.x * threads3.y * sizeof(T_ACC);
1257 
1258       cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, cuda_stream>>>(
1259                       part_grad_gamma.template data_ptr<T_ACC>(),
1260                       part_grad_beta.template data_ptr<T_ACC>(),
1261                       part_size,
1262                       M,N,
1263                       dgamma_data,
1264                       dbeta_data);
1265       C10_CUDA_KERNEL_LAUNCH_CHECK();
1266 #else
1267       if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) {
1268         // This implementation relies on warp primitives and requires that M and N
1269         // are exact multiples of the warp size.
1270         dim3 threads{kWarpSize, kWarpSize};
1271         int blocks = (N + threads.x - 1) / threads.x;
1272 
1273         // If M and N divide by warp_size, we can use warp shuffles for the final reduction.
1274         // That requires transposing values in shared memory, so we apply a padding to
1275         // reduce bank conflicts.
1276 
1277         size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y;
1278         GammaBetaBackwardCUDAKernel_32x32<T, T_ACC>
1279             <<<blocks, threads, shmem_sz, cuda_stream>>>(
1280                 M,
1281                 N,
1282                 dY_data,
1283                 X_data,
1284                 mean_data,
1285                 rstd_data,
1286                 dgamma_data,
1287                 dbeta_data);
1288           C10_CUDA_KERNEL_LAUNCH_CHECK();
1289       } else {
1290         dim3 threads{16, 32};
1291         int blocks = (N + threads.x - 1) / threads.x;
1292         size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y;
1293         GammaBetaBackwardCUDAKernel<T, T_ACC>
1294             <<<blocks, threads, shmem_sz, cuda_stream>>>(
1295                 M,
1296                 N,
1297                 dY_data,
1298                 X_data,
1299                 mean_data,
1300                 rstd_data,
1301                 dgamma_data,
1302                 dbeta_data);
1303         C10_CUDA_KERNEL_LAUNCH_CHECK();
1304       }
1305 #endif
1306     }
1307   }
1308 }
1309 
1310 void LayerNormBackwardKernelImpl(
1311     const Tensor& dY,
1312     const Tensor& X,
1313     const Tensor& mean,
1314     const Tensor& rstd,
1315     const Tensor& gamma,
1316     int64_t M,
1317     int64_t N,
1318     Tensor* dX,
1319     Tensor* dgamma,
1320     Tensor* dbeta) {
1321   AT_DISPATCH_FLOATING_TYPES_AND2(
1322       at::ScalarType::Half,
1323       at::ScalarType::BFloat16,
1324       X.scalar_type(),
1325       "LayerNormBackwardKernelImpl",
1326       [&]() {
1327         LayerNormBackwardKernelImplInternal<scalar_t>(
1328             dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
1329       });
1330 }
1331 
1332 } // namespace
1333 
1334 std::tuple<Tensor, Tensor, Tensor> layer_norm_cuda(
1335     const Tensor& input,
1336     IntArrayRef normalized_shape,
1337     const std::optional<Tensor>& weight_opt /* optional */,
1338     const std::optional<Tensor>& bias_opt /* optional */,
1339     double eps) {
1340   // See [Note: hacky wrapper removal for optional tensor]
1341   c10::MaybeOwned<Tensor> weight_maybe_owned =
1342       at::borrow_from_optional_tensor(weight_opt);
1343   const Tensor& weight = *weight_maybe_owned;
1344   c10::MaybeOwned<Tensor> bias_maybe_owned =
1345       at::borrow_from_optional_tensor(bias_opt);
1346   const Tensor& bias = *bias_maybe_owned;
1347 
1348   auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
1349   auto M = M_N.first;
1350   auto N = M_N.second;
1351   auto X = input.expect_contiguous();
1352   auto gamma = weight.expect_contiguous();
1353   auto beta = bias.expect_contiguous();
1354 
1355   Tensor Y = at::native::empty_like(
1356       *X,
1357       std::nullopt /* dtype */,
1358       std::nullopt /* layout */,
1359       std::nullopt /* device */,
1360       std::nullopt /* pin_memory */,
1361       LEGACY_CONTIGUOUS_MEMORY_FORMAT);
1362   auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true);
1363   Tensor mean = at::empty({M}, X->options().dtype(acc_type));
1364   Tensor rstd = at::empty({M}, X->options().dtype(acc_type));
1365   // Calling the kernel for M==0 gives a CUDA error
1366   // See: https://github.com/pytorch/pytorch/pull/28614
1367   if (M > 0) {
1368     LayerNormKernelImpl(*X, *gamma, *beta, M, N, eps, &Y, &mean, &rstd);
1369   }
1370   const auto input_shape = input.sizes();
1371   const size_t axis = input.dim() - normalized_shape.size();
1372 
1373   std::vector<int64_t> stat_shape;
1374   for (const auto idx: c10::irange(axis)) {
1375     stat_shape.push_back(input_shape[idx]);
1376   }
1377   for (const auto C10_UNUSED idx: c10::irange(axis, input.dim())) {
1378     stat_shape.push_back(1);
1379   }
1380 
1381   mean = mean.view(stat_shape);
1382   rstd = rstd.view(stat_shape);
1383 
1384   return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd));
1385 }
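// For reference: the stat_shape reshaping above turns the [M] statistics into
// input_shape[:axis] followed by ones, e.g. an input of shape [B, S, H]
// normalized over the last dimension yields mean and rstd of shape [B, S, 1].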
1386 
1387 std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cuda(
1388     const Tensor& dY,
1389     const Tensor& input,
1390     IntArrayRef normalized_shape,
1391     const Tensor& mean,
1392     const Tensor& rstd,
1393     const std::optional<Tensor>& weight_opt /* optional */,
1394     const std::optional<Tensor>& bias_opt /* optional */,
1395     std::array<bool, 3> grad_input_mask) {
1396   // See [Note: hacky wrapper removal for optional tensor]
1397   c10::MaybeOwned<Tensor> weight_maybe_owned =
1398       at::borrow_from_optional_tensor(weight_opt);
1399   const Tensor& weight = *weight_maybe_owned;
1400   c10::MaybeOwned<Tensor> bias_maybe_owned =
1401       at::borrow_from_optional_tensor(bias_opt);
1402   const Tensor& bias = *bias_maybe_owned;
1403 
1404   auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
1405   auto M = M_N.first;
1406   auto N = M_N.second;
1407   auto X = input.expect_contiguous();
1408   auto gamma = weight.expect_contiguous();
1409   auto beta = bias.expect_contiguous();
1410 
1411   Tensor dX;
1412   Tensor dgamma;
1413   Tensor dbeta;
1414   if (grad_input_mask[0]) {
1415     dX = at::native::empty_like(
1416         *X,
1417         std::nullopt /* dtype */,
1418         std::nullopt /* layout */,
1419         std::nullopt /* device */,
1420         std::nullopt /* pin_memory */,
1421         LEGACY_CONTIGUOUS_MEMORY_FORMAT);
1422   }
1423   if (grad_input_mask[1]) {
1424     dgamma = M > 0 ? at::native::empty_like(
1425                          *gamma,
1426                          std::nullopt /* dtype */,
1427                          std::nullopt /* layout */,
1428                          std::nullopt /* device */,
1429                          std::nullopt /* pin_memory */,
1430                          LEGACY_CONTIGUOUS_MEMORY_FORMAT)
1431                    : at::native::zeros_like(
1432                          *gamma,
1433                          std::nullopt /* dtype */,
1434                          std::nullopt /* layout */,
1435                          std::nullopt /* device */,
1436                          std::nullopt /* pin_memory */,
1437                          LEGACY_CONTIGUOUS_MEMORY_FORMAT);
1438   }
1439   if (grad_input_mask[2]) {
1440     dbeta = M > 0 ? at::native::empty_like(
1441                         *beta,
1442                         std::nullopt /* dtype */,
1443                         std::nullopt /* layout */,
1444                         std::nullopt /* device */,
1445                         std::nullopt /* pin_memory */,
1446                         LEGACY_CONTIGUOUS_MEMORY_FORMAT)
1447                   : at::native::zeros_like(
1448                         *beta,
1449                         std::nullopt /* dtype */,
1450                         std::nullopt /* layout */,
1451                         std::nullopt /* device */,
1452                         std::nullopt /* pin_memory */,
1453                         LEGACY_CONTIGUOUS_MEMORY_FORMAT);
1454   }
1455   if (M > 0 && N > 0) {
1456     LayerNormBackwardKernelImpl(
1457         dY, *X, mean, rstd, *gamma, M, N, &dX, &dgamma, &dbeta);
1458   }
1459   return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
1460 }
1461 
1462 REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl);
1463 REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl);
1464 
1465 } // namespace at::native
1466