xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/group_norm_kernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/group_norm.h>
3 
4 #include <type_traits>
5 
6 #include <thrust/tuple.h>
7 
8 #include <ATen/core/Tensor.h>
9 #include <ATen/AccumulateType.h>
10 #include <ATen/Dispatch.h>
11 #include <ATen/native/SharedReduceOps.h>
12 #include <ATen/native/TensorIterator.h>
13 #include <c10/cuda/CUDAMathCompat.h>
14 #include <ATen/cuda/detail/IndexUtils.cuh>
15 #include <ATen/native/cuda/Loops.cuh>
16 #include <ATen/native/cuda/block_reduce.cuh>
17 
18 #ifndef AT_PER_OPERATOR_HEADERS
19 #include <ATen/Functions.h>
20 #else
21 #include <ATen/ops/empty.h>
22 #endif
23 
24 namespace at::native {
25 
26 namespace {
27 
28 constexpr int kCUDANumThreads = 256;
29 constexpr int kReduceTileSize = 32;
30 
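// Computes, for each row selected by blockIdx.x (row length N), the row mean
// and rstd = 1 / sqrt(var + eps) using a Welford reduction: warp-level when
// blockDim.x <= warp size, block-level via shared memory otherwise.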
31 template <typename T>
32 __global__ void RowwiseMomentsCUDAKernel(
33     int64_t N,
34     T eps,
35     const T* X,
36     T* mean,
37     T* rstd) {
38   using T_ACC = acc_type<T, true>;
39   using WelfordType = WelfordData<T_ACC, int64_t>;
40   using WelfordOp =
41       WelfordOps<T_ACC, T_ACC, int64_t, thrust::pair<T_ACC, T_ACC>>;
42 
43   const int64_t i = blockIdx.x;
44   WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
45   WelfordType val(0, 0, 0, 0);
46   for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
47     const int64_t index = i * N + j;
48     val = welford_op.reduce(val, static_cast<T_ACC>(X[index]), index);
49   }
50   if (blockDim.x <= C10_WARP_SIZE) {
51     val = cuda_utils::WarpReduce(val, welford_op);
52   } else {
53     // There will be a warning if we declare a __shared__ WelfordType array.
54     // https://github.com/pytorch/pytorch/pull/13967
55     __shared__ typename std::aligned_storage<
56         sizeof(WelfordType),
57         alignof(WelfordType)>::type val_shared[C10_WARP_SIZE];
58     WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);
59     val = cuda_utils::BlockReduce(
60         val,
61         welford_op,
62         /*identity_element=*/WelfordType(0, 0, 0, 0),
63         val_shared_ptr);
64   }
65   if (threadIdx.x == 0) {
66     T_ACC m1;
67     T_ACC m2;
68     thrust::tie(m2, m1) = welford_op.project(val);
69     mean[i] = m1;
70     rstd[i] = c10::cuda::compat::rsqrt(m2 + static_cast<T_ACC>(eps));
71   }
72 }
73 
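// Precomputes per-(n, c) affine coefficients so the forward pass can be
// applied elementwise as Y = a * X + b, with a = rstd * gamma (or rstd when
// gamma is null) and b = beta - a * mean (beta treated as 0 when null).
// One thread handles one (n, c) entry.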
74 template <typename T>
75 __global__ void ComputeFusedParamsCUDAKernel(
76     int64_t N,
77     int64_t C,
78     int64_t group,
79     const T* mean,
80     const T* rstd,
81     const T* gamma,
82     const T* beta,
83     acc_type<T, true>* a,
84     acc_type<T, true>* b) {
85   using T_ACC = acc_type<T, true>;
86   const int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
87   if (index < N * C) {
88     const int64_t ng = index / (C / group);
89     const int64_t c = index % C;
90     const T_ACC scale = (gamma == nullptr)
91         ? static_cast<T_ACC>(rstd[ng])
92         : static_cast<T_ACC>(rstd[ng]) * static_cast<T_ACC>(gamma[c]);
93     a[index] = scale;
94     b[index] = -scale * static_cast<T_ACC>(mean[ng]) +
95         ((beta == nullptr) ? 0 : static_cast<T_ACC>(beta[c]));
96   }
97 }
98 
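// 1d (HxW == 1) backward: each (n, g) block reduces sum1 = sum_d(dY * X * gamma)
// and sum2 = sum_d(dY * gamma) over the D channels of the group, then derives
// the per-group coefficients c2 and c3 used in dX = rstd * gamma * dY + c2 * X + c3.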
99 template <typename T>
100 __global__ void Compute1dBackwardFusedParamsCUDAKernel(
101     int64_t C,
102     int64_t group,
103     const T* dY,
104     const T* X,
105     const T* mean,
106     const T* rstd,
107     const T* gamma,
108     acc_type<T, true>* c2,
109     acc_type<T, true>* c3) {
110   using T_ACC = acc_type<T, true>;
111   const int64_t G = group;
112   const int64_t D = C / G;
113   const int64_t n = blockIdx.x;
114   const int64_t g = blockIdx.y;
115   const int64_t ng = n * G + g;
116   T_ACC sum1 = 0;
117   T_ACC sum2 = 0;
118   for (int64_t i = threadIdx.x; i < D; i += blockDim.x) {
119     const int64_t index = ng * D + i;
120     const int64_t c = g * D + i;
121     const T_ACC gamma_v =
122         gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[c]);
123     sum1 += dY[index] * X[index] * gamma_v;
124     sum2 += dY[index] * gamma_v;
125   }
126   if (blockDim.x <= C10_WARP_SIZE) {
127     sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
128     sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
129   } else {
130     __shared__ T_ACC ds_shared[C10_WARP_SIZE];
131     __shared__ T_ACC db_shared[C10_WARP_SIZE];
132     sum1 = cuda_utils::BlockReduceSum<T_ACC>(sum1, ds_shared);
133     sum2 = cuda_utils::BlockReduceSum<T_ACC>(sum2, db_shared);
134   }
135   if (threadIdx.x == 0) {
136     const T_ACC s = T_ACC(1) / static_cast<T_ACC>(D);
137     const T_ACC x = (sum2 * static_cast<T_ACC>(mean[ng]) - sum1) *
138         static_cast<T_ACC>(rstd[ng]) * static_cast<T_ACC>(rstd[ng]) *
139         static_cast<T_ACC>(rstd[ng]) * s;
140     c2[ng] = x;
141     c3[ng] = -x * static_cast<T_ACC>(mean[ng]) -
142         sum2 * static_cast<T_ACC>(rstd[ng]) * s;
143   }
144 }
145 
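// 1d dgamma/dbeta backward with one thread per channel c, looping over the
// batch: dgamma[c] = sum_n dY * (X - mean) * rstd and dbeta[c] = sum_n dY.
// Used for small batches (N <= 128), where a plain columnwise loop suffices.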
146 template <typename T>
147 __global__ void GammaBeta1dBackwardCUDAKernel1(
148     int64_t N,
149     int64_t C,
150     int64_t group,
151     const T* dY,
152     const T* X,
153     const T* mean,
154     const T* rstd,
155     T* dgamma,
156     T* dbeta) {
157   using T_ACC = acc_type<T, true>;
158   const int64_t c = blockIdx.x * blockDim.x + threadIdx.x;
159   if (c < C) {
160     const int64_t G = group;
161     const int64_t D = C / G;
162     T_ACC sum1 = 0;
163     T_ACC sum2 = 0;
164     for (int64_t n = 0; n < N; ++n) {
165       const int64_t nc = n * C + c;
166       const int64_t ng = n * G + c / D;
167       const T_ACC dy_acc = static_cast<T_ACC>(dY[nc]);
168       const T_ACC x_acc = static_cast<T_ACC>(X[nc]);
169       sum1 += (dgamma == nullptr)
170           ? T_ACC(0)
171           : ((dy_acc * x_acc - dy_acc * static_cast<T_ACC>(mean[ng])) *
172              static_cast<T_ACC>(rstd[ng]));
173       sum2 += (dbeta == nullptr) ? T_ACC(0) : dy_acc;
174     }
175     if (dgamma != nullptr) {
176       dgamma[c] = sum1;
177     }
178     if (dbeta != nullptr) {
179       dbeta[c] = sum2;
180     }
181   }
182 }
183 
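// Tiled variant of the 1d dgamma/dbeta reduction for larger batches: each
// (32, 16) block accumulates 32 channels over N into a 32 x 32 shared-memory
// tile (each thread covers two rows), then warp-reduces each column. The +1
// padding of the tile rows is presumably there to avoid shared-memory bank
// conflicts on the transposed reads.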
184 template <typename T>
185 __global__ void GammaBeta1dBackwardCUDAKernel2(
186     int64_t N,
187     int64_t C,
188     int64_t group,
189     const T* dY,
190     const T* X,
191     const T* mean,
192     const T* rstd,
193     T* dgamma,
194     T* dbeta) {
195   using T_ACC = acc_type<T, true>;
196   __shared__ T_ACC g_shared[kReduceTileSize][kReduceTileSize + 1];
197   __shared__ T_ACC b_shared[kReduceTileSize][kReduceTileSize + 1];
198   const int64_t c = blockIdx.x * blockDim.x + threadIdx.x;
199   T_ACC dg_sum1 = 0;
200   T_ACC dg_sum2 = 0;
201   T_ACC db_sum1 = 0;
202   T_ACC db_sum2 = 0;
203   if (c < C) {
204     const int64_t G = group;
205     const int64_t D = C / G;
206     // Accumulate each block of 32 cols into a 32 * 32 tile.
207     // Since the blockDim is (32, 16), each thread accumulates twice, covering
208     // the 1st and 2nd sets of 16 rows of the tile.
209     for (int64_t n = threadIdx.y; n < N; n += blockDim.y * 2) {
210       const int64_t n1 = n;
211       const int64_t n2 = n + blockDim.y;
212       const int64_t nc1 = n1 * C + c;
213       const int64_t nc2 = n2 * C + c;
214       const int64_t ng1 = n1 * G + c / D;
215       const int64_t ng2 = n2 * G + c / D;
216       const T_ACC dy1_acc = static_cast<T_ACC>(dY[nc1]);
217       const T_ACC x1_acc = static_cast<T_ACC>(X[nc1]);
218       dg_sum1 += dgamma == nullptr
219           ? T_ACC(0)
220           : ((dy1_acc * x1_acc - dy1_acc * static_cast<T_ACC>(mean[ng1])) *
221              static_cast<T_ACC>(rstd[ng1]));
222       db_sum1 += dbeta == nullptr ? T_ACC(0) : dy1_acc;
223       if (n2 < N) {
224         const T_ACC dy2_acc = static_cast<T_ACC>(dY[nc2]);
225         const T_ACC x2_acc = static_cast<T_ACC>(X[nc2]);
226         dg_sum2 += dgamma == nullptr
227             ? T_ACC(0)
228             : ((dy2_acc * x2_acc - dy2_acc * static_cast<T_ACC>(mean[ng2])) *
229                static_cast<T_ACC>(rstd[ng2]));
230         db_sum2 += dbeta == nullptr ? T_ACC(0) : dy2_acc;
231       }
232     }
233   }
234 
235   // Write accumulated tile to shared memory.
236   g_shared[threadIdx.y][threadIdx.x] = dg_sum1;
237   g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2;
238   b_shared[threadIdx.y][threadIdx.x] = db_sum1;
239   b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2;
240   __syncthreads();
241 
242   // Do warp reduce for the 1st 16 cols in the tile.
243   T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y];
244   T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y];
245   sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
246   sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
247   if (threadIdx.x == 0) {
248     const int64_t c = blockIdx.x * blockDim.x + threadIdx.y;
249     if (c < C) {
250       if (dgamma != nullptr) {
251         dgamma[c] = sum1;
252       }
253       if (dbeta != nullptr) {
254         dbeta[c] = sum2;
255       }
256     }
257   }
258 
259   // Do warp reduce for the 2nd 16 cols in the tile.
260   sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y];
261   sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y];
262   sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
263   sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
264   if (threadIdx.x == 0) {
265     const int64_t c = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y;
266     if (c < C) {
267       if (dgamma != nullptr) {
268         dgamma[c] = sum1;
269       }
270       if (dbeta != nullptr) {
271         dbeta[c] = sum2;
272       }
273     }
274   }
275 }
276 
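// Per-(n, c) partial sums for the backward pass, one block per (n, c) plane:
// ds[nc] = sum_{hw} dY * X and db[nc] = sum_{hw} dY.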
277 template <typename T>
278 __global__ void ComputeInternalGradientsCUDAKernel(
279     int64_t HxW,
280     const T* dY,
281     const T* X,
282     acc_type<T, true>* ds,
283     acc_type<T, true>* db) {
284   using T_ACC = acc_type<T, true>;
285   const int64_t nc = blockIdx.x;
286   T_ACC sum1 = 0;
287   T_ACC sum2 = 0;
288   for (int64_t hw = threadIdx.x; hw < HxW; hw += blockDim.x) {
289     const int64_t index = nc * HxW + hw;
290     sum1 += static_cast<T_ACC>(dY[index]) * static_cast<T_ACC>(X[index]);
291     sum2 += static_cast<T_ACC>(dY[index]);
292   }
293   if (blockDim.x <= C10_WARP_SIZE) {
294     sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
295     sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
296   } else {
297     __shared__ T_ACC ds_shared[C10_WARP_SIZE];
298     __shared__ T_ACC db_shared[C10_WARP_SIZE];
299     sum1 = cuda_utils::BlockReduceSum<T_ACC>(sum1, ds_shared);
300     sum2 = cuda_utils::BlockReduceSum<T_ACC>(sum2, db_shared);
301   }
302   if (threadIdx.x == 0) {
303     ds[nc] = sum1;
304     db[nc] = sum2;
305   }
306 }
307 
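// Reduces ds/db (scaled by gamma) over the D channels of each (n, g) group and
// derives the coefficients c2 and c3 of the fused gradient
// dX = c1 * dY + c2 * X + c3, where c1 = rstd * gamma is computed separately.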
308 template <typename T>
309 __global__ void ComputeBackwardFusedParamsCUDAKernel(
310     int64_t C,
311     int64_t HxW,
312     int64_t group,
313     const T* mean,
314     const T* rstd,
315     const T* gamma,
316     const acc_type<T, true>* ds,
317     const acc_type<T, true>* db,
318     acc_type<T, true>* c2,
319     acc_type<T, true>* c3) {
320   using T_ACC = acc_type<T, true>;
321   const int64_t G = group;
322   const int64_t D = C / G;
323   const int64_t n = blockIdx.x;
324   const int64_t g = blockIdx.y;
325   const int64_t ng = n * G + g;
326   T_ACC sum1 = 0;
327   T_ACC sum2 = 0;
328   for (int64_t i = threadIdx.x; i < D; i += blockDim.x) {
329     const int64_t index = ng * D + i;
330     const int64_t c = g * D + i;
331     const T_ACC gamma_v =
332         gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[c]);
333     sum1 += ds[index] * gamma_v;
334     sum2 += db[index] * gamma_v;
335   }
336   if (blockDim.x <= C10_WARP_SIZE) {
337     sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
338     sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
339   } else {
340     __shared__ T_ACC ds_shared[C10_WARP_SIZE];
341     __shared__ T_ACC db_shared[C10_WARP_SIZE];
342     sum1 = cuda_utils::BlockReduceSum<T_ACC>(sum1, ds_shared);
343     sum2 = cuda_utils::BlockReduceSum<T_ACC>(sum2, db_shared);
344   }
345   if (threadIdx.x == 0) {
346     const T_ACC s = T_ACC(1) / static_cast<T_ACC>(D * HxW);
347     const T_ACC x = (sum2 * static_cast<T_ACC>(mean[ng]) - sum1) *
348         static_cast<T_ACC>(rstd[ng]) * static_cast<T_ACC>(rstd[ng]) *
349         static_cast<T_ACC>(rstd[ng]) * s;
350     c2[ng] = x;
351     c3[ng] = -x * static_cast<T_ACC>(mean[ng]) -
352         sum2 * static_cast<T_ACC>(rstd[ng]) * s;
353   }
354 }
355 
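// dgamma/dbeta from the precomputed ds/db sums, one thread per channel:
// dgamma[c] = sum_n (ds[nc] - db[nc] * mean) * rstd and dbeta[c] = sum_n db[nc].
// Used for small batches (N <= 128).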
356 template <typename T>
357 __global__ void GammaBetaBackwardCUDAKernel1(
358     int64_t N,
359     int64_t C,
360     int64_t group,
361     const T* mean,
362     const T* rstd,
363     const acc_type<T, true>* ds,
364     const acc_type<T, true>* db,
365     T* dgamma,
366     T* dbeta) {
367   using T_ACC = acc_type<T, true>;
368   const int64_t c = blockIdx.x * blockDim.x + threadIdx.x;
369   if (c < C) {
370     const int64_t G = group;
371     const int64_t D = C / G;
372     T_ACC sum1 = 0;
373     T_ACC sum2 = 0;
374     for (int64_t n = 0; n < N; ++n) {
375       const int64_t nc = n * C + c;
376       const int64_t ng = n * G + c / D;
377       sum1 += (dgamma == nullptr)
378           ? T_ACC(0)
379           : ((ds[nc] - db[nc] * static_cast<T_ACC>(mean[ng])) *
380              static_cast<T_ACC>(rstd[ng]));
381       sum2 += (dbeta == nullptr) ? T_ACC(0) : db[nc];
382     }
383     if (dgamma != nullptr) {
384       dgamma[c] = sum1;
385     }
386     if (dbeta != nullptr) {
387       dbeta[c] = sum2;
388     }
389   }
390 }
391 
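// Tiled variant of GammaBetaBackwardCUDAKernel1 for larger batches, using the
// same 32 x 32 shared-memory tile and warp-reduce scheme as
// GammaBeta1dBackwardCUDAKernel2 above.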
392 template <typename T>
393 __global__ void GammaBetaBackwardCUDAKernel2(
394     int64_t N,
395     int64_t C,
396     int64_t group,
397     const T* mean,
398     const T* rstd,
399     const acc_type<T, true>* ds,
400     const acc_type<T, true>* db,
401     T* dgamma,
402     T* dbeta) {
403   using T_ACC = acc_type<T, true>;
404   __shared__ T_ACC g_shared[kReduceTileSize][kReduceTileSize + 1];
405   __shared__ T_ACC b_shared[kReduceTileSize][kReduceTileSize + 1];
406   const int64_t c = blockIdx.x * blockDim.x + threadIdx.x;
407   T_ACC dg_sum1 = 0;
408   T_ACC dg_sum2 = 0;
409   T_ACC db_sum1 = 0;
410   T_ACC db_sum2 = 0;
411   if (c < C) {
412     const int64_t G = group;
413     const int64_t D = C / G;
414     // Accumulate each block of 32 cols into a 32 * 32 tile.
415     // Since the blockDim is (32, 16), each thread accumulates twice, covering
416     // the 1st and 2nd sets of 16 rows of the tile.
417     for (int64_t n = threadIdx.y; n < N; n += blockDim.y * 2) {
418       const int64_t n1 = n;
419       const int64_t n2 = n + blockDim.y;
420       const int64_t nc1 = n1 * C + c;
421       const int64_t nc2 = n2 * C + c;
422       const int64_t ng1 = n1 * G + c / D;
423       const int64_t ng2 = n2 * G + c / D;
424       dg_sum1 += dgamma == nullptr
425           ? T_ACC(0)
426           : ((ds[nc1] - db[nc1] * static_cast<T_ACC>(mean[ng1])) *
427              static_cast<T_ACC>(rstd[ng1]));
428       db_sum1 += dbeta == nullptr ? T_ACC(0) : db[nc1];
429       if (n2 < N) {
430         dg_sum2 += dgamma == nullptr
431             ? T_ACC(0)
432             : ((ds[nc2] - db[nc2] * static_cast<T_ACC>(mean[ng2])) *
433                static_cast<T_ACC>(rstd[ng2]));
434         db_sum2 += dbeta == nullptr ? T_ACC(0) : db[nc2];
435       }
436     }
437   }
438 
439   // Write accumulated tile to shared memory.
440   g_shared[threadIdx.y][threadIdx.x] = dg_sum1;
441   g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2;
442   b_shared[threadIdx.y][threadIdx.x] = db_sum1;
443   b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2;
444   __syncthreads();
445 
446   // Do warp reduce for the 1st 16 cols in the tile.
447   T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y];
448   T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y];
449   sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
450   sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
451   if (threadIdx.x == 0) {
452     const int64_t c = blockIdx.x * blockDim.x + threadIdx.y;
453     if (c < C) {
454       if (dgamma != nullptr) {
455         dgamma[c] = sum1;
456       }
457       if (dbeta != nullptr) {
458         dbeta[c] = sum2;
459       }
460     }
461   }
462 
463   // Do warp reduce for the 2nd 16 cols in the tile.
464   sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y];
465   sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y];
466   sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
467   sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
468   if (threadIdx.x == 0) {
469     const int64_t c = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y;
470     if (c < C) {
471       if (dgamma != nullptr) {
472         dgamma[c] = sum1;
473       }
474       if (dbeta != nullptr) {
475         dbeta[c] = sum2;
476       }
477     }
478   }
479 }
480 
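// Forward path for HxW == 1: applies Y = (X - mean) * rstd * gamma + beta
// elementwise through TensorIterator/gpu_kernel, with the gamma-only,
// beta-only, and no-affine cases handled by separate lambdas.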
481 template <typename T>
482 void GroupNorm1dForward(
483     const Tensor& X,
484     const Tensor& mean,
485     const Tensor& rstd,
486     const Tensor& gamma,
487     const Tensor& beta,
488     int64_t N,
489     int64_t C,
490     int64_t group,
491     Tensor& Y) {
492   using T_ACC = acc_type<T, true>;
493   const int64_t G = group;
494   const int64_t D = C / G;
495   if (gamma.defined() && beta.defined()) {
496     auto iter = TensorIteratorConfig()
497                     .resize_outputs(false)
498                     .add_owned_output(Y.view({N, G, D}))
499                     .add_owned_const_input(X.view({N, G, D}))
500                     .add_owned_input(mean.view({N, G, 1}))
501                     .add_owned_input(rstd.view({N, G, 1}))
502                     .add_owned_const_input(gamma.view({1, G, D}))
503                     .add_owned_const_input(beta.view({1, G, D}))
504                     .build();
505     gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T gamma, T beta) -> T {
506       return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
507           static_cast<T_ACC>(rstd) * static_cast<T_ACC>(gamma) +
508           static_cast<T_ACC>(beta);
509     });
510   } else if (gamma.defined()) {
511     auto iter = TensorIteratorConfig()
512                     .resize_outputs(false)
513                     .add_owned_output(Y.view({N, G, D}))
514                     .add_owned_const_input(X.view({N, G, D}))
515                     .add_owned_input(mean.view({N, G, 1}))
516                     .add_owned_input(rstd.view({N, G, 1}))
517                     .add_owned_const_input(gamma.view({1, G, D}))
518                     .build();
519     gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T gamma) -> T {
520       return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
521           static_cast<T_ACC>(rstd) * static_cast<T_ACC>(gamma);
522     });
523   } else if (beta.defined()) {
524     auto iter = TensorIteratorConfig()
525                     .resize_outputs(false)
526                     .add_owned_output(Y.view({N, G, D}))
527                     .add_owned_const_input(X.view({N, G, D}))
528                     .add_owned_input(mean.view({N, G, 1}))
529                     .add_owned_input(rstd.view({N, G, 1}))
530                     .add_owned_const_input(beta.view({1, G, D}))
531                     .build();
532     gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T beta) -> T {
533       return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
534           static_cast<T_ACC>(rstd) +
535           static_cast<T_ACC>(beta);
536     });
537   } else {
538     auto iter = TensorIteratorConfig()
539                     .resize_outputs(false)
540                     .add_owned_output(Y.view({N * G, D}))
541                     .add_owned_const_input(X.view({N * G, D}))
542                     .add_owned_input(mean.view({N * G, 1}))
543                     .add_owned_input(rstd.view({N * G, 1}))
544                     .build();
545     gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd) -> T {
546       return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
547           static_cast<T_ACC>(rstd);
548     });
549   }
550   AT_CUDA_CHECK(cudaGetLastError());
551 }
552 
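// Forward implementation for one dtype: computes mean/rstd per (n, g) row with
// RowwiseMomentsCUDAKernel, then dispatches to GroupNorm1dForward (HxW == 1),
// normalizes directly when no affine parameters are given, or precomputes the
// per-(n, c) a/b coefficients and applies Y = a * X + b.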
553 template <typename T>
554 void GroupNormKernelImplInternal(
555     const Tensor& X,
556     const Tensor& gamma,
557     const Tensor& beta,
558     int64_t N,
559     int64_t C,
560     int64_t HxW,
561     int64_t group,
562     T eps,
563     Tensor& Y,
564     Tensor& mean,
565     Tensor& rstd) {
566   using T_ACC = acc_type<T, true>;
567   TORCH_CHECK(X.numel() == N * C * HxW);
568   TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
569   TORCH_CHECK(!beta.defined() || beta.numel() == C);
570   if (N == 0) {
571     return;
572   }
573   const int64_t G = group;
574   const int64_t D = C / G;
575   const T* X_data = X.const_data_ptr<T>();
576   T* mean_data = mean.mutable_data_ptr<T>();
577   T* rstd_data = rstd.mutable_data_ptr<T>();
578 
579   cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
580   const int64_t num_threads = D * HxW < cuda_utils::kCUDABlockReduceNumThreads
581       ? at::cuda::warp_size()
582       : cuda_utils::kCUDABlockReduceNumThreads;
583   RowwiseMomentsCUDAKernel<T><<<N * G, num_threads, 0, cuda_stream>>>(
584       D * HxW, eps, X_data, mean_data, rstd_data);
585   C10_CUDA_KERNEL_LAUNCH_CHECK();
586 
587   if (HxW == 1) {
588     GroupNorm1dForward<T>(X, mean, rstd, gamma, beta, N, C, G, Y);
589   } else if (!gamma.defined() && !beta.defined()) {
590     auto iter = TensorIteratorConfig()
591                     .resize_outputs(false)
592                     .add_owned_output(Y.view({N * G, D * HxW}))
593                     .add_owned_const_input(X.view({N * G, D * HxW}))
594                     .add_owned_input(mean.view({N * G, 1}))
595                     .add_owned_input(rstd.view({N * G, 1}))
596                     .build();
597     gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd) -> T {
598       return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
599           static_cast<T_ACC>(rstd);
600     });
601   } else {
602     const auto kAccType =
603         (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16)
604         ? kFloat
605         : X.scalar_type();
606     Tensor a = at::empty({N, C}, X.options().dtype(kAccType));
607     Tensor b = at::empty({N, C}, X.options().dtype(kAccType));
608     const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
609     const T* beta_data = beta.defined() ? beta.const_data_ptr<T>() : nullptr;
610     T_ACC* a_data = a.mutable_data_ptr<T_ACC>();
611     T_ACC* b_data = b.mutable_data_ptr<T_ACC>();
612 
613     // TODO: Since there are some issues in gpu_kernel_multiple_outputs, we
614     // use a manual kernel here. Switch to gpu_kernel_multiple_outputs once
615     // the issue is fixed.
616     const int64_t B = (N * C + kCUDANumThreads - 1) / kCUDANumThreads;
617     ComputeFusedParamsCUDAKernel<T><<<B, kCUDANumThreads, 0, cuda_stream>>>(
618         N, C, G, mean_data, rstd_data, gamma_data, beta_data, a_data, b_data);
619     C10_CUDA_KERNEL_LAUNCH_CHECK();
620 
621     auto iter = TensorIteratorConfig()
622                     .check_all_same_dtype(std::is_same<T, T_ACC>::value)
623                     .resize_outputs(false)
624                     .add_owned_output(Y.view({N * C, HxW}))
625                     .add_owned_const_input(X.view({N * C, HxW}))
626                     .add_owned_input(a.view({N * C, 1}))
627                     .add_owned_input(b.view({N * C, 1}))
628                     .build();
629     gpu_kernel(iter, [] GPU_LAMBDA(T x, T_ACC a, T_ACC b) -> T {
630       return a * static_cast<T_ACC>(x) + b;
631     });
632   }
633   AT_CUDA_CHECK(cudaGetLastError());
634 }
635 
636 void GroupNormKernelImpl(
637     const Tensor& X,
638     const Tensor& gamma,
639     const Tensor& beta,
640     int64_t N,
641     int64_t C,
642     int64_t HxW,
643     int64_t group,
644     double eps,
645     Tensor& Y,
646     Tensor& mean,
647     Tensor& rstd) {
648   AT_DISPATCH_FLOATING_TYPES_AND2(
649       at::ScalarType::Half,
650       at::ScalarType::BFloat16,
651       X.scalar_type(),
652       "GroupNormKernelImpl",
653       [&]() {
654         GroupNormKernelImplInternal<scalar_t>(
655             X,
656             gamma,
657             beta,
658             N,
659             C,
660             HxW,
661             group,
662             static_cast<scalar_t>(eps),
663             Y,
664             mean,
665             rstd);
666       });
667 }
668 
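// Backward path for HxW == 1: computes the per-(n, g) coefficients c2/c3 with
// Compute1dBackwardFusedParamsCUDAKernel, applies
// dX = rstd * gamma * dY + c2 * X + c3 elementwise, and reduces dgamma/dbeta
// with the direct kernel (small N) or the tiled kernel.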
669 template <typename T>
670 void GroupNorm1dBackward(
671     const Tensor dY,
672     const Tensor X,
673     const Tensor mean,
674     const Tensor rstd,
675     const Tensor gamma,
676     int64_t N,
677     int64_t C,
678     int64_t group,
679     Tensor& dX,
680     Tensor& dgamma,
681     Tensor& dbeta) {
682   using T_ACC = acc_type<T, true>;
683   const int64_t G = group;
684   const int64_t D = C / G;
685   const T* dY_data = dY.const_data_ptr<T>();
686   const T* X_data = X.const_data_ptr<T>();
687   const T* mean_data = mean.const_data_ptr<T>();
688   const T* rstd_data = rstd.const_data_ptr<T>();
689 
690   cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
691   if (dX.defined()) {
692     const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
693     const auto kAccType =
694         (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16)
695         ? kFloat
696         : X.scalar_type();
697     Tensor c2 = at::empty({N, G}, X.options().dtype(kAccType));
698     Tensor c3 = at::empty({N, G}, X.options().dtype(kAccType));
699     T_ACC* c2_data = c2.mutable_data_ptr<T_ACC>();
700     T_ACC* c3_data = c3.mutable_data_ptr<T_ACC>();
701     const int64_t num_threads = (C / G) < cuda_utils::kCUDABlockReduceNumThreads
702         ? at::cuda::warp_size()
703         : cuda_utils::kCUDABlockReduceNumThreads;
704     Compute1dBackwardFusedParamsCUDAKernel<T>
705         <<<dim3(N, G), num_threads, 0, cuda_stream>>>(
706             C,
707             G,
708             dY_data,
709             X_data,
710             mean_data,
711             rstd_data,
712             gamma_data,
713             c2_data,
714             c3_data);
715     C10_CUDA_KERNEL_LAUNCH_CHECK();
716 
717     if (gamma.defined()) {
718       auto iter = TensorIteratorConfig()
719                       .check_all_same_dtype(std::is_same<T, T_ACC>::value)
720                       .resize_outputs(false)
721                       .add_owned_output(dX.view({N, G, D}))
722                       .add_owned_const_input(dY.view({N, G, D}))
723                       .add_owned_const_input(X.view({N, G, D}))
724                       .add_owned_const_input(rstd.view({N, G, 1}))
725                       .add_owned_const_input(gamma.view({1, G, D}))
726                       .add_owned_const_input(c2.view({N, G, 1}))
727                       .add_owned_const_input(c3.view({N, G, 1}))
728                       .build();
729       gpu_kernel(
730           iter,
731           [] GPU_LAMBDA(T dy, T x, T rstd, T gamma, T_ACC c2, T_ACC c3) -> T {
732             const T_ACC c1 =
733                 static_cast<T_ACC>(rstd) * static_cast<T_ACC>(gamma);
734             return c1 * static_cast<T_ACC>(dy) + c2 * static_cast<T_ACC>(x) +
735                 c3;
736           });
737     } else {
738       auto iter = TensorIteratorConfig()
739                       .check_all_same_dtype(std::is_same<T, T_ACC>::value)
740                       .resize_outputs(false)
741                       .add_owned_output(dX.view({N * G, D}))
742                       .add_owned_const_input(dY.view({N * G, D}))
743                       .add_owned_const_input(X.view({N * G, D}))
744                       .add_owned_const_input(rstd.view({N * G, 1}))
745                       .add_owned_const_input(c2.view({N * G, 1}))
746                       .add_owned_const_input(c3.view({N * G, 1}))
747                       .build();
748       gpu_kernel(
749           iter, [] GPU_LAMBDA(T dy, T x, T rstd, T_ACC c2, T_ACC c3) -> T {
750             const T_ACC c1 = static_cast<T_ACC>(rstd);
751             return c1 * static_cast<T_ACC>(dy) + c2 * static_cast<T_ACC>(x) +
752                 c3;
753           });
754     }
755   }
756   if (dgamma.defined() || dbeta.defined()) {
757     T* dgamma_data = dgamma.defined() ? dgamma.mutable_data_ptr<T>() : nullptr;
758     T* dbeta_data = dbeta.defined() ? dbeta.mutable_data_ptr<T>() : nullptr;
759     if (N <= 128) {
760       const int64_t B = (C + kCUDANumThreads - 1) / kCUDANumThreads;
761       GammaBeta1dBackwardCUDAKernel1<T><<<B, kCUDANumThreads, 0, cuda_stream>>>(
762           N,
763           C,
764           G,
765           dY_data,
766           X_data,
767           mean_data,
768           rstd_data,
769           dgamma_data,
770           dbeta_data);
771       C10_CUDA_KERNEL_LAUNCH_CHECK();
772     } else {
773       const int64_t B = (C + kReduceTileSize - 1) / kReduceTileSize;
774       // The algorithm for the colwise reduction here is to accumulate each
775       // block of 32 cols into a 32 * 32 tile in shared memory, then do a warp
776       // reduce for each col in the tile. So the blockDim here must be (32, 16).
777       constexpr int kThreadX = kReduceTileSize;
778       constexpr int kThreadY = kReduceTileSize / 2;
779       GammaBeta1dBackwardCUDAKernel2<T>
780           <<<B, dim3(kThreadX, kThreadY), 0, cuda_stream>>>(
781               N,
782               C,
783               G,
784               dY_data,
785               X_data,
786               mean_data,
787               rstd_data,
788               dgamma_data,
789               dbeta_data);
790       C10_CUDA_KERNEL_LAUNCH_CHECK();
791     }
792   }
793 }
794 
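// Backward implementation for one dtype: delegates to GroupNorm1dBackward when
// HxW == 1; otherwise reduces ds/db per (n, c), builds c1 = rstd * gamma and
// the per-(n, g) c2/c3 coefficients for dX = c1 * dY + c2 * X + c3, and reduces
// dgamma/dbeta with the direct or tiled kernel depending on N.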
795 template <typename T>
796 void GroupNormBackwardKernelImplInternal(
797     const Tensor& dY,
798     const Tensor& X,
799     const Tensor& mean,
800     const Tensor& rstd,
801     const Tensor& gamma,
802     int64_t N,
803     int64_t C,
804     int64_t HxW,
805     int64_t group,
806     Tensor& dX,
807     Tensor& dgamma,
808     Tensor& dbeta) {
809   using T_ACC = acc_type<T, true>;
810   const int64_t G = group;
811   const int64_t D = C / G;
812   TORCH_CHECK(dY.numel() == N * C * HxW);
813   TORCH_CHECK(X.numel() == N * C * HxW);
814   TORCH_CHECK(mean.numel() == N * G);
815   TORCH_CHECK(rstd.numel() == N * G);
816   TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
817   cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
818 
819   if (N == 0) {
820     if (dgamma.defined()) {
821       dgamma.fill_(T(0));
822     }
823     if (dbeta.defined()) {
824       dbeta.fill_(T(0));
825     }
826     return;
827   }
828 
829   const T* dY_data = dY.const_data_ptr<T>();
830   const T* X_data = X.const_data_ptr<T>();
831   const T* mean_data = mean.const_data_ptr<T>();
832   const T* rstd_data = rstd.const_data_ptr<T>();
833   const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
834   const auto kAccType =
835       (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16)
836       ? kFloat
837       : X.scalar_type();
838   Tensor ds = at::empty({N, C}, X.options().dtype(kAccType));
839   Tensor db = at::empty({N, C}, X.options().dtype(kAccType));
840   T_ACC* ds_data = ds.mutable_data_ptr<T_ACC>();
841   T_ACC* db_data = db.mutable_data_ptr<T_ACC>();
842 
843   if (HxW == 1) {
844     GroupNorm1dBackward<T>(
845         dY, X, mean, rstd, gamma, N, C, G, dX, dgamma, dbeta);
846     return;
847   }
848 
849   int warp_size = at::cuda::warp_size();
850   int64_t num_threads = HxW < cuda_utils::kCUDABlockReduceNumThreads
851       ? warp_size
852       : cuda_utils::kCUDABlockReduceNumThreads;
853   ComputeInternalGradientsCUDAKernel<T><<<N * C, num_threads, 0, cuda_stream>>>(
854       HxW, dY_data, X_data, ds_data, db_data);
855   C10_CUDA_KERNEL_LAUNCH_CHECK();
856 
857   if (dX.defined()) {
858     Tensor c1 = at::empty({0}, X.options().dtype(kAccType));
859     Tensor c2 = at::empty({N, G}, X.options().dtype(kAccType));
860     Tensor c3 = at::empty({N, G}, X.options().dtype(kAccType));
861     T_ACC* c2_data = c2.mutable_data_ptr<T_ACC>();
862     T_ACC* c3_data = c3.mutable_data_ptr<T_ACC>();
863 
864     if (gamma.defined()) {
865       auto iter = TensorIteratorConfig()
866                       .check_all_same_dtype(std::is_same<T, T_ACC>::value)
867                       .add_output(c1)
868                       .add_owned_const_input(rstd.view({N, G, 1}))
869                       .add_owned_const_input(gamma.view({1, G, D}))
870                       .build();
871       gpu_kernel(iter, [] GPU_LAMBDA(T rstd, T gamma) -> T_ACC {
872         return static_cast<T_ACC>(rstd) * static_cast<T_ACC>(gamma);
873       });
874     }
875 
876     num_threads = (C / G) < cuda_utils::kCUDABlockReduceNumThreads
877         ? warp_size
878         : cuda_utils::kCUDABlockReduceNumThreads;
879     ComputeBackwardFusedParamsCUDAKernel<T>
880         <<<dim3(N, G), num_threads, 0, cuda_stream>>>(
881             C,
882             HxW,
883             G,
884             mean_data,
885             rstd_data,
886             gamma_data,
887             ds_data,
888             db_data,
889             c2_data,
890             c3_data);
891     C10_CUDA_KERNEL_LAUNCH_CHECK();
892 
893     if (gamma.defined()) {
894       auto iter = TensorIteratorConfig()
895                       .check_all_same_dtype(std::is_same<T, T_ACC>::value)
896                       .resize_outputs(false)
897                       .add_owned_output(dX.view({N * G, D, HxW}))
898                       .add_owned_const_input(dY.view({N * G, D, HxW}))
899                       .add_owned_const_input(X.view({N * G, D, HxW}))
900                       .add_owned_const_input(c1.view({N * G, D, 1}))
901                       .add_owned_const_input(c2.view({N * G, 1, 1}))
902                       .add_owned_const_input(c3.view({N * G, 1, 1}))
903                       .build();
904       gpu_kernel(
905           iter, [] GPU_LAMBDA(T dy, T x, T_ACC c1, T_ACC c2, T_ACC c3) -> T {
906             return c1 * static_cast<T_ACC>(dy) + c2 * static_cast<T_ACC>(x) +
907                 c3;
908           });
909     } else {
910       auto iter = TensorIteratorConfig()
911                       .check_all_same_dtype(std::is_same<T, T_ACC>::value)
912                       .resize_outputs(false)
913                       .add_owned_output(dX.view({N * G, D * HxW}))
914                       .add_owned_const_input(dY.view({N * G, D * HxW}))
915                       .add_owned_const_input(X.view({N * G, D * HxW}))
916                       .add_owned_const_input(rstd.view({N * G, 1}))
917                       .add_owned_const_input(c2.view({N * G, 1}))
918                       .add_owned_const_input(c3.view({N * G, 1}))
919                       .build();
920       gpu_kernel(
921           iter, [] GPU_LAMBDA(T dy, T x, T_ACC c1, T_ACC c2, T_ACC c3) -> T {
922             return c1 * static_cast<T_ACC>(dy) + c2 * static_cast<T_ACC>(x) +
923                 c3;
924           });
925     }
926   }
927   if (dgamma.defined() || dbeta.defined()) {
928     T* dgamma_data = dgamma.defined() ? dgamma.mutable_data_ptr<T>() : nullptr;
929     T* dbeta_data = dbeta.defined() ? dbeta.mutable_data_ptr<T>() : nullptr;
930     if (N <= 128) {
931       // For small batch size, do colwise reduce directly.
932       const int64_t B = (C + kCUDANumThreads - 1) / kCUDANumThreads;
933       GammaBetaBackwardCUDAKernel1<T><<<B, kCUDANumThreads, 0, cuda_stream>>>(
934           N,
935           C,
936           G,
937           mean_data,
938           rstd_data,
939           ds_data,
940           db_data,
941           dgamma_data,
942           dbeta_data);
943       C10_CUDA_KERNEL_LAUNCH_CHECK();
944     } else {
945       const int64_t B = (C + kReduceTileSize - 1) / kReduceTileSize;
946       // The algorithm for the colwise reduction here is to accumulate each
947       // block of 32 cols into a 32 * 32 tile in shared memory, then do a warp
948       // reduce for each col in the tile. So the blockDim here must be (32, 16).
949       constexpr int kThreadX = kReduceTileSize;
950       constexpr int kThreadY = kReduceTileSize / 2;
951       GammaBetaBackwardCUDAKernel2<T>
952           <<<B, dim3(kThreadX, kThreadY), 0, cuda_stream>>>(
953               N,
954               C,
955               G,
956               mean_data,
957               rstd_data,
958               ds_data,
959               db_data,
960               dgamma_data,
961               dbeta_data);
962       C10_CUDA_KERNEL_LAUNCH_CHECK();
963     }
964   }
965 }
966 
967 void GroupNormBackwardKernelImpl(
968     const Tensor& dY,
969     const Tensor& X,
970     const Tensor& mean,
971     const Tensor& rstd,
972     const Tensor& gamma,
973     int64_t N,
974     int64_t C,
975     int64_t HxW,
976     int64_t group,
977     Tensor& dX,
978     Tensor& dgamma,
979     Tensor& dbeta) {
980   AT_DISPATCH_FLOATING_TYPES_AND2(
981       at::ScalarType::Half,
982       at::ScalarType::BFloat16,
983       X.scalar_type(),
984       "GroupNormBackwardKernelImpl",
985       [&]() {
986         GroupNormBackwardKernelImplInternal<scalar_t>(
987             dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
988       });
989 }
990 
991 } // namespace
992 
993 REGISTER_DISPATCH(GroupNormKernel, &GroupNormKernelImpl);
994 REGISTER_DISPATCH(GroupNormBackwardKernel, &GroupNormBackwardKernelImpl);
995 
996 } // namespace at::native
997