#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/group_norm.h>

#include <type_traits>

#include <thrust/tuple.h>

#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/block_reduce.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace at::native {

namespace {

constexpr int kCUDANumThreads = 256;
constexpr int kReduceTileSize = 32;

template <typename T>
__global__ void RowwiseMomentsCUDAKernel(
    int64_t N,
    T eps,
    const T* X,
    T* mean,
    T* rstd) {
  using T_ACC = acc_type<T, true>;
  using WelfordType = WelfordData<T_ACC, int64_t>;
  using WelfordOp =
      WelfordOps<T_ACC, T_ACC, int64_t, thrust::pair<T_ACC, T_ACC>>;

  const int64_t i = blockIdx.x;
  WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
  WelfordType val(0, 0, 0, 0);
  for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
    const int64_t index = i * N + j;
    val = welford_op.reduce(val, static_cast<T_ACC>(X[index]), index);
  }
  if (blockDim.x <= C10_WARP_SIZE) {
    val = cuda_utils::WarpReduce(val, welford_op);
  } else {
    // There will be a warning if we declare a __shared__ WelfordType array.
    // https://github.com/pytorch/pytorch/pull/13967
    __shared__ typename std::aligned_storage<
        sizeof(WelfordType),
        alignof(WelfordType)>::type val_shared[C10_WARP_SIZE];
    WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);
    val = cuda_utils::BlockReduce(
        val,
        welford_op,
        /*identity_element=*/WelfordType(0, 0, 0, 0),
        val_shared_ptr);
  }
  if (threadIdx.x == 0) {
    T_ACC m1;
    T_ACC m2;
    thrust::tie(m2, m1) = welford_op.project(val);
    mean[i] = m1;
    rstd[i] = c10::cuda::compat::rsqrt(m2 + static_cast<T_ACC>(eps));
  }
}
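// Note on RowwiseMomentsCUDAKernel: each block owns one row of N = D * HxW
// elements and runs a Welford reduction with correction = 0 (biased variance).
// With m1 the row mean and m2 the row variance, the kernel stores
//   mean[i] = m1
//   rstd[i] = 1 / sqrt(m2 + eps)
// so that (x - mean[i]) * rstd[i] is the normalized value consumed by the
// forward kernels below.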

template <typename T>
__global__ void ComputeFusedParamsCUDAKernel(
    int64_t N,
    int64_t C,
    int64_t group,
    const T* mean,
    const T* rstd,
    const T* gamma,
    const T* beta,
    acc_type<T, true>* a,
    acc_type<T, true>* b) {
  using T_ACC = acc_type<T, true>;
  const int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < N * C) {
    const int64_t ng = index / (C / group);
    const int64_t c = index % C;
    const T_ACC scale = (gamma == nullptr)
        ? static_cast<T_ACC>(rstd[ng])
        : static_cast<T_ACC>(rstd[ng]) * static_cast<T_ACC>(gamma[c]);
    a[index] = scale;
    b[index] = -scale * static_cast<T_ACC>(mean[ng]) +
        ((beta == nullptr) ? 0 : static_cast<T_ACC>(beta[c]));
  }
}
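// Note on ComputeFusedParamsCUDAKernel: the affine forward
//   y = (x - mean) * rstd * gamma + beta
// is folded into a single multiply-add y = a * x + b per (n, c), with
//   a[n, c] = rstd[n, g] * gamma[c]       (gamma[c] = 1 when gamma is null)
//   b[n, c] = beta[c] - a[n, c] * mean[n, g]   (beta[c] = 0 when beta is null)
// which is exactly what the elementwise gpu_kernel in
// GroupNormKernelImplInternal applies. A scalar reference of the equivalence
// (hypothetical helper, illustration only, not part of this file):
//
//   template <typename T_ACC>
//   T_ACC group_norm_ref(T_ACC x, T_ACC m, T_ACC r, T_ACC g, T_ACC be) {
//     return (x - m) * r * g + be;  // == (r * g) * x + (be - r * g * m)
//   }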

template <typename T>
__global__ void Compute1dBackwardFusedParamsCUDAKernel(
    int64_t C,
    int64_t group,
    const T* dY,
    const T* X,
    const T* mean,
    const T* rstd,
    const T* gamma,
    acc_type<T, true>* c2,
    acc_type<T, true>* c3) {
  using T_ACC = acc_type<T, true>;
  const int64_t G = group;
  const int64_t D = C / G;
  const int64_t n = blockIdx.x;
  const int64_t g = blockIdx.y;
  const int64_t ng = n * G + g;
  T_ACC sum1 = 0;
  T_ACC sum2 = 0;
  for (int64_t i = threadIdx.x; i < D; i += blockDim.x) {
    const int64_t index = ng * D + i;
    const int64_t c = g * D + i;
    const T_ACC gamma_v =
        gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[c]);
    sum1 += dY[index] * X[index] * gamma_v;
    sum2 += dY[index] * gamma_v;
  }
  if (blockDim.x <= C10_WARP_SIZE) {
    sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
    sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
  } else {
    __shared__ T_ACC ds_shared[C10_WARP_SIZE];
    __shared__ T_ACC db_shared[C10_WARP_SIZE];
    sum1 = cuda_utils::BlockReduceSum<T_ACC>(sum1, ds_shared);
    sum2 = cuda_utils::BlockReduceSum<T_ACC>(sum2, db_shared);
  }
  if (threadIdx.x == 0) {
    const T_ACC s = T_ACC(1) / static_cast<T_ACC>(D);
    const T_ACC x = (sum2 * static_cast<T_ACC>(mean[ng]) - sum1) *
        static_cast<T_ACC>(rstd[ng]) * static_cast<T_ACC>(rstd[ng]) *
        static_cast<T_ACC>(rstd[ng]) * s;
    c2[ng] = x;
    c3[ng] = -x * static_cast<T_ACC>(mean[ng]) -
        sum2 * static_cast<T_ACC>(rstd[ng]) * s;
  }
}
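// Note on Compute1dBackwardFusedParamsCUDAKernel: for the HxW == 1 case each
// (n, g) block reduces over the D channels of its group,
//   sum1 = sum_i dY * X * gamma,   sum2 = sum_i dY * gamma,
// and emits the per-group coefficients applied in GroupNorm1dBackward:
//   c2[ng] = (sum2 * mean[ng] - sum1) * rstd[ng]^3 / D
//   c3[ng] = -c2[ng] * mean[ng] - sum2 * rstd[ng] / D
// so that dX = rstd * gamma * dY + c2 * X + c3.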

template <typename T>
__global__ void GammaBeta1dBackwardCUDAKernel1(
    int64_t N,
    int64_t C,
    int64_t group,
    const T* dY,
    const T* X,
    const T* mean,
    const T* rstd,
    T* dgamma,
    T* dbeta) {
  using T_ACC = acc_type<T, true>;
  const int64_t c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < C) {
    const int64_t G = group;
    const int64_t D = C / G;
    T_ACC sum1 = 0;
    T_ACC sum2 = 0;
    for (int64_t n = 0; n < N; ++n) {
      const int64_t nc = n * C + c;
      const int64_t ng = n * G + c / D;
      const T_ACC dy_acc = static_cast<T_ACC>(dY[nc]);
      const T_ACC x_acc = static_cast<T_ACC>(X[nc]);
      sum1 += (dgamma == nullptr)
          ? T_ACC(0)
          : ((dy_acc * x_acc - dy_acc * static_cast<T_ACC>(mean[ng])) *
             static_cast<T_ACC>(rstd[ng]));
      sum2 += (dbeta == nullptr) ? T_ACC(0) : dy_acc;
    }
    if (dgamma != nullptr) {
      dgamma[c] = sum1;
    }
    if (dbeta != nullptr) {
      dbeta[c] = sum2;
    }
  }
}
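// Note on GammaBeta1dBackwardCUDAKernel1: one thread per channel c reduces over
// the batch dimension,
//   dgamma[c] = sum_n dY[n, c] * (X[n, c] - mean[n, g]) * rstd[n, g]
//   dbeta[c]  = sum_n dY[n, c]
// This simple column-wise loop is used only for small batches (N <= 128);
// larger batches use the tiled kernel below.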

template <typename T>
__global__ void GammaBeta1dBackwardCUDAKernel2(
    int64_t N,
    int64_t C,
    int64_t group,
    const T* dY,
    const T* X,
    const T* mean,
    const T* rstd,
    T* dgamma,
    T* dbeta) {
  using T_ACC = acc_type<T, true>;
  __shared__ T_ACC g_shared[kReduceTileSize][kReduceTileSize + 1];
  __shared__ T_ACC b_shared[kReduceTileSize][kReduceTileSize + 1];
  const int64_t c = blockIdx.x * blockDim.x + threadIdx.x;
  T_ACC dg_sum1 = 0;
  T_ACC dg_sum2 = 0;
  T_ACC db_sum1 = 0;
  T_ACC db_sum2 = 0;
  if (c < C) {
    const int64_t G = group;
    const int64_t D = C / G;
    // Accumulate each 32 cols into a 32 * 32 tile.
    // Since the blockDim is (32, 16), each iteration accumulates two rows per
    // 32-row chunk: one from the 1st 16 rows and one from the 2nd 16 rows.
    for (int64_t n = threadIdx.y; n < N; n += blockDim.y * 2) {
      const int64_t n1 = n;
      const int64_t n2 = n + blockDim.y;
      const int64_t nc1 = n1 * C + c;
      const int64_t nc2 = n2 * C + c;
      const int64_t ng1 = n1 * G + c / D;
      const int64_t ng2 = n2 * G + c / D;
      const T_ACC dy1_acc = static_cast<T_ACC>(dY[nc1]);
      const T_ACC x1_acc = static_cast<T_ACC>(X[nc1]);
      dg_sum1 += dgamma == nullptr
          ? T_ACC(0)
          : ((dy1_acc * x1_acc - dy1_acc * static_cast<T_ACC>(mean[ng1])) *
             static_cast<T_ACC>(rstd[ng1]));
      db_sum1 += dbeta == nullptr ? T_ACC(0) : dy1_acc;
      if (n2 < N) {
        const T_ACC dy2_acc = static_cast<T_ACC>(dY[nc2]);
        const T_ACC x2_acc = static_cast<T_ACC>(X[nc2]);
        dg_sum2 += dgamma == nullptr
            ? T_ACC(0)
            : ((dy2_acc * x2_acc - dy2_acc * static_cast<T_ACC>(mean[ng2])) *
               static_cast<T_ACC>(rstd[ng2]));
        db_sum2 += dbeta == nullptr ? T_ACC(0) : dy2_acc;
      }
    }
  }

  // Write accumulated tile to shared memory.
  g_shared[threadIdx.y][threadIdx.x] = dg_sum1;
  g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2;
  b_shared[threadIdx.y][threadIdx.x] = db_sum1;
  b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2;
  __syncthreads();

  // Do warp reduce for the 1st 16 cols in the tile.
  T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y];
  T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y];
  sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
  sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
  if (threadIdx.x == 0) {
    const int64_t c = blockIdx.x * blockDim.x + threadIdx.y;
    if (c < C) {
      if (dgamma != nullptr) {
        dgamma[c] = sum1;
      }
      if (dbeta != nullptr) {
        dbeta[c] = sum2;
      }
    }
  }

  // Do warp reduce for the 2nd 16 cols in the tile.
  sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y];
  sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y];
  sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
  sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
  if (threadIdx.x == 0) {
    const int64_t c = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y;
    if (c < C) {
      if (dgamma != nullptr) {
        dgamma[c] = sum1;
      }
      if (dbeta != nullptr) {
        dbeta[c] = sum2;
      }
    }
  }
}
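// Note on GammaBeta1dBackwardCUDAKernel2: the per-channel sums are the same as
// in kernel 1, but the reduction over n is blocked into kReduceTileSize x
// kReduceTileSize tiles held in shared memory (padded by one column, a common
// trick to avoid bank conflicts), read back transposed, and finished with a
// warp reduction so that a full warp cooperates on each output channel and
// consecutive threads write consecutive entries of dgamma/dbeta.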

template <typename T>
__global__ void ComputeInternalGradientsCUDAKernel(
    int64_t HxW,
    const T* dY,
    const T* X,
    acc_type<T, true>* ds,
    acc_type<T, true>* db) {
  using T_ACC = acc_type<T, true>;
  const int64_t nc = blockIdx.x;
  T_ACC sum1 = 0;
  T_ACC sum2 = 0;
  for (int64_t hw = threadIdx.x; hw < HxW; hw += blockDim.x) {
    const int64_t index = nc * HxW + hw;
    sum1 += static_cast<T_ACC>(dY[index]) * static_cast<T_ACC>(X[index]);
    sum2 += static_cast<T_ACC>(dY[index]);
  }
  if (blockDim.x <= C10_WARP_SIZE) {
    sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
    sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
  } else {
    __shared__ T_ACC ds_shared[C10_WARP_SIZE];
    __shared__ T_ACC db_shared[C10_WARP_SIZE];
    sum1 = cuda_utils::BlockReduceSum<T_ACC>(sum1, ds_shared);
    sum2 = cuda_utils::BlockReduceSum<T_ACC>(sum2, db_shared);
  }
  if (threadIdx.x == 0) {
    ds[nc] = sum1;
    db[nc] = sum2;
  }
}
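// Note on ComputeInternalGradientsCUDAKernel: one block per (n, c) slice
// produces the two per-channel partial sums that every backward kernel reuses:
//   ds[n, c] = sum_{hw} dY[n, c, hw] * X[n, c, hw]
//   db[n, c] = sum_{hw} dY[n, c, hw]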

template <typename T>
__global__ void ComputeBackwardFusedParamsCUDAKernel(
    int64_t C,
    int64_t HxW,
    int64_t group,
    const T* mean,
    const T* rstd,
    const T* gamma,
    const acc_type<T, true>* ds,
    const acc_type<T, true>* db,
    acc_type<T, true>* c2,
    acc_type<T, true>* c3) {
  using T_ACC = acc_type<T, true>;
  const int64_t G = group;
  const int64_t D = C / G;
  const int64_t n = blockIdx.x;
  const int64_t g = blockIdx.y;
  const int64_t ng = n * G + g;
  T_ACC sum1 = 0;
  T_ACC sum2 = 0;
  for (int64_t i = threadIdx.x; i < D; i += blockDim.x) {
    const int64_t index = ng * D + i;
    const int64_t c = g * D + i;
    const T_ACC gamma_v =
        gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[c]);
    sum1 += ds[index] * gamma_v;
    sum2 += db[index] * gamma_v;
  }
  if (blockDim.x <= C10_WARP_SIZE) {
    sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
    sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
  } else {
    __shared__ T_ACC ds_shared[C10_WARP_SIZE];
    __shared__ T_ACC db_shared[C10_WARP_SIZE];
    sum1 = cuda_utils::BlockReduceSum<T_ACC>(sum1, ds_shared);
    sum2 = cuda_utils::BlockReduceSum<T_ACC>(sum2, db_shared);
  }
  if (threadIdx.x == 0) {
    const T_ACC s = T_ACC(1) / static_cast<T_ACC>(D * HxW);
    const T_ACC x = (sum2 * static_cast<T_ACC>(mean[ng]) - sum1) *
        static_cast<T_ACC>(rstd[ng]) * static_cast<T_ACC>(rstd[ng]) *
        static_cast<T_ACC>(rstd[ng]) * s;
    c2[ng] = x;
    c3[ng] = -x * static_cast<T_ACC>(mean[ng]) -
        sum2 * static_cast<T_ACC>(rstd[ng]) * s;
  }
}
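// Note on ComputeBackwardFusedParamsCUDAKernel: the general-HxW analogue of the
// 1d kernel above, fed from the precomputed ds/db sums. With
//   sum1 = sum_c ds[n, c] * gamma[c],   sum2 = sum_c db[n, c] * gamma[c]
// taken over the D channels of group g (gamma[c] = 1 when gamma is null) and
// s = 1 / (D * HxW),
//   c2[ng] = (sum2 * mean[ng] - sum1) * rstd[ng]^3 * s
//   c3[ng] = -c2[ng] * mean[ng] - sum2 * rstd[ng] * s
// and dX = c1 * dY + c2 * X + c3 with c1 = rstd * gamma, as applied in
// GroupNormBackwardKernelImplInternal.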

template <typename T>
__global__ void GammaBetaBackwardCUDAKernel1(
    int64_t N,
    int64_t C,
    int64_t group,
    const T* mean,
    const T* rstd,
    const acc_type<T, true>* ds,
    const acc_type<T, true>* db,
    T* dgamma,
    T* dbeta) {
  using T_ACC = acc_type<T, true>;
  const int64_t c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < C) {
    const int64_t G = group;
    const int64_t D = C / G;
    T_ACC sum1 = 0;
    T_ACC sum2 = 0;
    for (int64_t n = 0; n < N; ++n) {
      const int64_t nc = n * C + c;
      const int64_t ng = n * G + c / D;
      sum1 += (dgamma == nullptr)
          ? T_ACC(0)
          : ((ds[nc] - db[nc] * static_cast<T_ACC>(mean[ng])) *
             static_cast<T_ACC>(rstd[ng]));
      sum2 += (dbeta == nullptr) ? T_ACC(0) : db[nc];
    }
    if (dgamma != nullptr) {
      dgamma[c] = sum1;
    }
    if (dbeta != nullptr) {
      dbeta[c] = sum2;
    }
  }
}

template <typename T>
__global__ void GammaBetaBackwardCUDAKernel2(
    int64_t N,
    int64_t C,
    int64_t group,
    const T* mean,
    const T* rstd,
    const acc_type<T, true>* ds,
    const acc_type<T, true>* db,
    T* dgamma,
    T* dbeta) {
  using T_ACC = acc_type<T, true>;
  __shared__ T_ACC g_shared[kReduceTileSize][kReduceTileSize + 1];
  __shared__ T_ACC b_shared[kReduceTileSize][kReduceTileSize + 1];
  const int64_t c = blockIdx.x * blockDim.x + threadIdx.x;
  T_ACC dg_sum1 = 0;
  T_ACC dg_sum2 = 0;
  T_ACC db_sum1 = 0;
  T_ACC db_sum2 = 0;
  if (c < C) {
    const int64_t G = group;
    const int64_t D = C / G;
    // Accumulate each 32 cols into a 32 * 32 tile.
    // Since the blockDim is (32, 16), each iteration accumulates two rows per
    // 32-row chunk: one from the 1st 16 rows and one from the 2nd 16 rows.
    for (int64_t n = threadIdx.y; n < N; n += blockDim.y * 2) {
      const int64_t n1 = n;
      const int64_t n2 = n + blockDim.y;
      const int64_t nc1 = n1 * C + c;
      const int64_t nc2 = n2 * C + c;
      const int64_t ng1 = n1 * G + c / D;
      const int64_t ng2 = n2 * G + c / D;
      dg_sum1 += dgamma == nullptr
          ? T_ACC(0)
          : ((ds[nc1] - db[nc1] * static_cast<T_ACC>(mean[ng1])) *
             static_cast<T_ACC>(rstd[ng1]));
      db_sum1 += dbeta == nullptr ? T_ACC(0) : db[nc1];
      if (n2 < N) {
        dg_sum2 += dgamma == nullptr
            ? T_ACC(0)
            : ((ds[nc2] - db[nc2] * static_cast<T_ACC>(mean[ng2])) *
               static_cast<T_ACC>(rstd[ng2]));
        db_sum2 += dbeta == nullptr ? T_ACC(0) : db[nc2];
      }
    }
  }

  // Write accumulated tile to shared memory.
  g_shared[threadIdx.y][threadIdx.x] = dg_sum1;
  g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2;
  b_shared[threadIdx.y][threadIdx.x] = db_sum1;
  b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2;
  __syncthreads();

  // Do warp reduce for the 1st 16 cols in the tile.
  T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y];
  T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y];
  sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
  sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
  if (threadIdx.x == 0) {
    const int64_t c = blockIdx.x * blockDim.x + threadIdx.y;
    if (c < C) {
      if (dgamma != nullptr) {
        dgamma[c] = sum1;
      }
      if (dbeta != nullptr) {
        dbeta[c] = sum2;
      }
    }
  }

  // Do warp reduce for the 2nd 16 cols in the tile.
  sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y];
  sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y];
  sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
  sum2 = cuda_utils::WarpReduceSum<T_ACC>(sum2);
  if (threadIdx.x == 0) {
    const int64_t c = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y;
    if (c < C) {
      if (dgamma != nullptr) {
        dgamma[c] = sum1;
      }
      if (dbeta != nullptr) {
        dbeta[c] = sum2;
      }
    }
  }
}
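// Note on GammaBetaBackwardCUDAKernel1/2: these mirror the 1d variants above
// but read the precomputed ds/db sums instead of dY and X directly, since
//   dgamma[c] = sum_n (ds[n, c] - db[n, c] * mean[n, g]) * rstd[n, g]
//   dbeta[c]  = sum_n db[n, c]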

template <typename T>
void GroupNorm1dForward(
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t N,
    int64_t C,
    int64_t group,
    Tensor& Y) {
  using T_ACC = acc_type<T, true>;
  const int64_t G = group;
  const int64_t D = C / G;
  if (gamma.defined() && beta.defined()) {
    auto iter = TensorIteratorConfig()
                    .resize_outputs(false)
                    .add_owned_output(Y.view({N, G, D}))
                    .add_owned_const_input(X.view({N, G, D}))
                    .add_owned_input(mean.view({N, G, 1}))
                    .add_owned_input(rstd.view({N, G, 1}))
                    .add_owned_const_input(gamma.view({1, G, D}))
                    .add_owned_const_input(beta.view({1, G, D}))
                    .build();
    gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T gamma, T beta) -> T {
      return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
          static_cast<T_ACC>(rstd) * static_cast<T_ACC>(gamma) +
          static_cast<T_ACC>(beta);
    });
  } else if (gamma.defined()) {
    auto iter = TensorIteratorConfig()
                    .resize_outputs(false)
                    .add_owned_output(Y.view({N, G, D}))
                    .add_owned_const_input(X.view({N, G, D}))
                    .add_owned_input(mean.view({N, G, 1}))
                    .add_owned_input(rstd.view({N, G, 1}))
                    .add_owned_const_input(gamma.view({1, G, D}))
                    .build();
    gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T gamma) -> T {
      return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
          static_cast<T_ACC>(rstd) * static_cast<T_ACC>(gamma);
    });
  } else if (beta.defined()) {
    auto iter = TensorIteratorConfig()
                    .resize_outputs(false)
                    .add_owned_output(Y.view({N, G, D}))
                    .add_owned_const_input(X.view({N, G, D}))
                    .add_owned_input(mean.view({N, G, 1}))
                    .add_owned_input(rstd.view({N, G, 1}))
                    .add_owned_const_input(beta.view({1, G, D}))
                    .build();
    gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T beta) -> T {
      return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
          static_cast<T_ACC>(rstd) +
          static_cast<T_ACC>(beta);
    });
  } else {
    auto iter = TensorIteratorConfig()
                    .resize_outputs(false)
                    .add_owned_output(Y.view({N * G, D}))
                    .add_owned_const_input(X.view({N * G, D}))
                    .add_owned_input(mean.view({N * G, 1}))
                    .add_owned_input(rstd.view({N * G, 1}))
                    .build();
    gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd) -> T {
      return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
          static_cast<T_ACC>(rstd);
    });
  }
  AT_CUDA_CHECK(cudaGetLastError());
}
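// Note on GroupNorm1dForward: for HxW == 1 the normalization is expressed as a
// broadcasting TensorIterator, with mean/rstd viewed as {N, G, 1} and
// gamma/beta as {1, G, D} against the {N, G, D} input; the four branches only
// differ in which affine tensors are defined. A rough tensor-level equivalent
// of the fully affine branch (illustration only, not part of this file):
//
//   // Y = (X.view({N, G, D}) - mean.view({N, G, 1})) * rstd.view({N, G, 1})
//   //         * gamma.view({1, G, D}) + beta.view({1, G, D});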

template <typename T>
void GroupNormKernelImplInternal(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    T eps,
    Tensor& Y,
    Tensor& mean,
    Tensor& rstd) {
  using T_ACC = acc_type<T, true>;
  TORCH_CHECK(X.numel() == N * C * HxW);
  TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
  TORCH_CHECK(!beta.defined() || beta.numel() == C);
  if (N == 0) {
    return;
  }
  const int64_t G = group;
  const int64_t D = C / G;
  const T* X_data = X.const_data_ptr<T>();
  T* mean_data = mean.mutable_data_ptr<T>();
  T* rstd_data = rstd.mutable_data_ptr<T>();

  cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
  const int64_t num_threads = D * HxW < cuda_utils::kCUDABlockReduceNumThreads
      ? at::cuda::warp_size()
      : cuda_utils::kCUDABlockReduceNumThreads;
  RowwiseMomentsCUDAKernel<T><<<N * G, num_threads, 0, cuda_stream>>>(
      D * HxW, eps, X_data, mean_data, rstd_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  if (HxW == 1) {
    GroupNorm1dForward<T>(X, mean, rstd, gamma, beta, N, C, G, Y);
  } else if (!gamma.defined() && !beta.defined()) {
    auto iter = TensorIteratorConfig()
                    .resize_outputs(false)
                    .add_owned_output(Y.view({N * G, D * HxW}))
                    .add_owned_const_input(X.view({N * G, D * HxW}))
                    .add_owned_input(mean.view({N * G, 1}))
                    .add_owned_input(rstd.view({N * G, 1}))
                    .build();
    gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd) -> T {
      return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
          static_cast<T_ACC>(rstd);
    });
  } else {
    const auto kAccType =
        (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16)
        ? kFloat
        : X.scalar_type();
    Tensor a = at::empty({N, C}, X.options().dtype(kAccType));
    Tensor b = at::empty({N, C}, X.options().dtype(kAccType));
    const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
    const T* beta_data = beta.defined() ? beta.const_data_ptr<T>() : nullptr;
    T_ACC* a_data = a.mutable_data_ptr<T_ACC>();
    T_ACC* b_data = b.mutable_data_ptr<T_ACC>();

    // TODO: Since there are some issues with gpu_kernel_multiple_outputs, we
    // are using a manual kernel here. Switch to gpu_kernel_multiple_outputs
    // once the issue is fixed.
    const int64_t B = (N * C + kCUDANumThreads - 1) / kCUDANumThreads;
    ComputeFusedParamsCUDAKernel<T><<<B, kCUDANumThreads, 0, cuda_stream>>>(
        N, C, G, mean_data, rstd_data, gamma_data, beta_data, a_data, b_data);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    auto iter = TensorIteratorConfig()
                    .check_all_same_dtype(std::is_same<T, T_ACC>::value)
                    .resize_outputs(false)
                    .add_owned_output(Y.view({N * C, HxW}))
                    .add_owned_const_input(X.view({N * C, HxW}))
                    .add_owned_input(a.view({N * C, 1}))
                    .add_owned_input(b.view({N * C, 1}))
                    .build();
    gpu_kernel(iter, [] GPU_LAMBDA(T x, T_ACC a, T_ACC b) -> T {
      return a * static_cast<T_ACC>(x) + b;
    });
  }
  AT_CUDA_CHECK(cudaGetLastError());
}
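// Launch-configuration note for GroupNormKernelImplInternal: the moments kernel
// runs one block per (n, g) row; a single warp is enough when the row length
// D * HxW is below cuda_utils::kCUDABlockReduceNumThreads, otherwise a full
// block reduction is used. When gamma or beta is defined, the fused a/b
// parameters are computed by the manual kernel above (see the TODO) and then
// applied with a broadcasting gpu_kernel over {N * C, HxW}.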

void GroupNormKernelImpl(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps,
    Tensor& Y,
    Tensor& mean,
    Tensor& rstd) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      X.scalar_type(),
      "GroupNormKernelImpl",
      [&]() {
        GroupNormKernelImplInternal<scalar_t>(
            X,
            gamma,
            beta,
            N,
            C,
            HxW,
            group,
            static_cast<scalar_t>(eps),
            Y,
            mean,
            rstd);
      });
}
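// Example (a minimal sketch, not part of this file): the kernel registered
// below is reached through the standard dispatcher, e.g. from hypothetical
// host-side client code such as
//
//   at::Tensor x = at::randn({8, 32, 16, 16},
//                            at::TensorOptions().device(at::kCUDA));
//   at::Tensor w = at::ones({32}, x.options());
//   at::Tensor b = at::zeros({32}, x.options());
//   at::Tensor y = at::group_norm(x, /*num_groups=*/8, w, b, /*eps=*/1e-5);
//
// which would end up calling GroupNormKernelImpl with N = 8, C = 32,
// HxW = 256, group = 8.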

template <typename T>
void GroupNorm1dBackward(
    const Tensor dY,
    const Tensor X,
    const Tensor mean,
    const Tensor rstd,
    const Tensor gamma,
    int64_t N,
    int64_t C,
    int64_t group,
    Tensor& dX,
    Tensor& dgamma,
    Tensor& dbeta) {
  using T_ACC = acc_type<T, true>;
  const int64_t G = group;
  const int64_t D = C / G;
  const T* dY_data = dY.const_data_ptr<T>();
  const T* X_data = X.const_data_ptr<T>();
  const T* mean_data = mean.const_data_ptr<T>();
  const T* rstd_data = rstd.const_data_ptr<T>();

  cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
  if (dX.defined()) {
    const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
    const auto kAccType =
        (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16)
        ? kFloat
        : X.scalar_type();
    Tensor c2 = at::empty({N, G}, X.options().dtype(kAccType));
    Tensor c3 = at::empty({N, G}, X.options().dtype(kAccType));
    T_ACC* c2_data = c2.mutable_data_ptr<T_ACC>();
    T_ACC* c3_data = c3.mutable_data_ptr<T_ACC>();
    const int64_t num_threads = (C / G) < cuda_utils::kCUDABlockReduceNumThreads
        ? at::cuda::warp_size()
        : cuda_utils::kCUDABlockReduceNumThreads;
    Compute1dBackwardFusedParamsCUDAKernel<T>
        <<<dim3(N, G), num_threads, 0, cuda_stream>>>(
            C,
            G,
            dY_data,
            X_data,
            mean_data,
            rstd_data,
            gamma_data,
            c2_data,
            c3_data);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    if (gamma.defined()) {
      auto iter = TensorIteratorConfig()
                      .check_all_same_dtype(std::is_same<T, T_ACC>::value)
                      .resize_outputs(false)
                      .add_owned_output(dX.view({N, G, D}))
                      .add_owned_const_input(dY.view({N, G, D}))
                      .add_owned_const_input(X.view({N, G, D}))
                      .add_owned_const_input(rstd.view({N, G, 1}))
                      .add_owned_const_input(gamma.view({1, G, D}))
                      .add_owned_const_input(c2.view({N, G, 1}))
                      .add_owned_const_input(c3.view({N, G, 1}))
                      .build();
      gpu_kernel(
          iter,
          [] GPU_LAMBDA(T dy, T x, T rstd, T gamma, T_ACC c2, T_ACC c3) -> T {
            const T_ACC c1 =
                static_cast<T_ACC>(rstd) * static_cast<T_ACC>(gamma);
            return c1 * static_cast<T_ACC>(dy) + c2 * static_cast<T_ACC>(x) +
                c3;
          });
    } else {
      auto iter = TensorIteratorConfig()
                      .check_all_same_dtype(std::is_same<T, T_ACC>::value)
                      .resize_outputs(false)
                      .add_owned_output(dX.view({N * G, D}))
                      .add_owned_const_input(dY.view({N * G, D}))
                      .add_owned_const_input(X.view({N * G, D}))
                      .add_owned_const_input(rstd.view({N * G, 1}))
                      .add_owned_const_input(c2.view({N * G, 1}))
                      .add_owned_const_input(c3.view({N * G, 1}))
                      .build();
      gpu_kernel(
          iter, [] GPU_LAMBDA(T dy, T x, T rstd, T_ACC c2, T_ACC c3) -> T {
            const T_ACC c1 = static_cast<T_ACC>(rstd);
            return c1 * static_cast<T_ACC>(dy) + c2 * static_cast<T_ACC>(x) +
                c3;
          });
    }
  }
  if (dgamma.defined() || dbeta.defined()) {
    T* dgamma_data = dgamma.defined() ? dgamma.mutable_data_ptr<T>() : nullptr;
    T* dbeta_data = dbeta.defined() ? dbeta.mutable_data_ptr<T>() : nullptr;
    if (N <= 128) {
      const int64_t B = (C + kCUDANumThreads - 1) / kCUDANumThreads;
      GammaBeta1dBackwardCUDAKernel1<T><<<B, kCUDANumThreads, 0, cuda_stream>>>(
          N,
          C,
          G,
          dY_data,
          X_data,
          mean_data,
          rstd_data,
          dgamma_data,
          dbeta_data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      const int64_t B = (C + kReduceTileSize - 1) / kReduceTileSize;
      // The algorithm for colwise reduction here is to accumulate each 32 cols
      // to a 32 * 32 tile and write the tile to shared memory. Then do warp
      // reduce for each col in the tile. So here the blockDim must be (32, 16).
      constexpr int kThreadX = kReduceTileSize;
      constexpr int kThreadY = kReduceTileSize / 2;
      GammaBeta1dBackwardCUDAKernel2<T>
          <<<B, dim3(kThreadX, kThreadY), 0, cuda_stream>>>(
              N,
              C,
              G,
              dY_data,
              X_data,
              mean_data,
              rstd_data,
              dgamma_data,
              dbeta_data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
}
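// Note on GroupNorm1dBackward: dX is produced by a broadcasting gpu_kernel from
// the fused c2/c3 coefficients computed above, while dgamma/dbeta switch
// between the plain column-wise kernel (N <= 128) and the tiled shared-memory
// kernel, the same heuristic used in the general backward path below.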

template <typename T>
void GroupNormBackwardKernelImplInternal(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    Tensor& dX,
    Tensor& dgamma,
    Tensor& dbeta) {
  using T_ACC = acc_type<T, true>;
  const int64_t G = group;
  const int64_t D = C / G;
  TORCH_CHECK(dY.numel() == N * C * HxW);
  TORCH_CHECK(X.numel() == N * C * HxW);
  TORCH_CHECK(mean.numel() == N * G);
  TORCH_CHECK(rstd.numel() == N * G);
  TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
  cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();

  if (N == 0) {
    if (dgamma.defined()) {
      dgamma.fill_(T(0));
    }
    if (dbeta.defined()) {
      dbeta.fill_(T(0));
    }
    return;
  }

  const T* dY_data = dY.const_data_ptr<T>();
  const T* X_data = X.const_data_ptr<T>();
  const T* mean_data = mean.const_data_ptr<T>();
  const T* rstd_data = rstd.const_data_ptr<T>();
  const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
  const auto kAccType =
      (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16)
      ? kFloat
      : X.scalar_type();
  Tensor ds = at::empty({N, C}, X.options().dtype(kAccType));
  Tensor db = at::empty({N, C}, X.options().dtype(kAccType));
  T_ACC* ds_data = ds.mutable_data_ptr<T_ACC>();
  T_ACC* db_data = db.mutable_data_ptr<T_ACC>();

  if (HxW == 1) {
    GroupNorm1dBackward<T>(
        dY, X, mean, rstd, gamma, N, C, G, dX, dgamma, dbeta);
    return;
  }

  int warp_size = at::cuda::warp_size();
  int64_t num_threads = HxW < cuda_utils::kCUDABlockReduceNumThreads
      ? warp_size
      : cuda_utils::kCUDABlockReduceNumThreads;
  ComputeInternalGradientsCUDAKernel<T><<<N * C, num_threads, 0, cuda_stream>>>(
      HxW, dY_data, X_data, ds_data, db_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  if (dX.defined()) {
    Tensor c1 = at::empty({0}, X.options().dtype(kAccType));
    Tensor c2 = at::empty({N, G}, X.options().dtype(kAccType));
    Tensor c3 = at::empty({N, G}, X.options().dtype(kAccType));
    T_ACC* c2_data = c2.mutable_data_ptr<T_ACC>();
    T_ACC* c3_data = c3.mutable_data_ptr<T_ACC>();

    if (gamma.defined()) {
      auto iter = TensorIteratorConfig()
                      .check_all_same_dtype(std::is_same<T, T_ACC>::value)
                      .add_output(c1)
                      .add_owned_const_input(rstd.view({N, G, 1}))
                      .add_owned_const_input(gamma.view({1, G, D}))
                      .build();
      gpu_kernel(iter, [] GPU_LAMBDA(T rstd, T gamma) -> T_ACC {
        return static_cast<T_ACC>(rstd) * static_cast<T_ACC>(gamma);
      });
    }

    num_threads = (C / G) < cuda_utils::kCUDABlockReduceNumThreads
        ? warp_size
        : cuda_utils::kCUDABlockReduceNumThreads;
    ComputeBackwardFusedParamsCUDAKernel<T>
        <<<dim3(N, G), num_threads, 0, cuda_stream>>>(
            C,
            HxW,
            G,
            mean_data,
            rstd_data,
            gamma_data,
            ds_data,
            db_data,
            c2_data,
            c3_data);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    if (gamma.defined()) {
      auto iter = TensorIteratorConfig()
                      .check_all_same_dtype(std::is_same<T, T_ACC>::value)
                      .resize_outputs(false)
                      .add_owned_output(dX.view({N * G, D, HxW}))
                      .add_owned_const_input(dY.view({N * G, D, HxW}))
                      .add_owned_const_input(X.view({N * G, D, HxW}))
                      .add_owned_const_input(c1.view({N * G, D, 1}))
                      .add_owned_const_input(c2.view({N * G, 1, 1}))
                      .add_owned_const_input(c3.view({N * G, 1, 1}))
                      .build();
      gpu_kernel(
          iter, [] GPU_LAMBDA(T dy, T x, T_ACC c1, T_ACC c2, T_ACC c3) -> T {
            return c1 * static_cast<T_ACC>(dy) + c2 * static_cast<T_ACC>(x) +
                c3;
          });
    } else {
      auto iter = TensorIteratorConfig()
                      .check_all_same_dtype(std::is_same<T, T_ACC>::value)
                      .resize_outputs(false)
                      .add_owned_output(dX.view({N * G, D * HxW}))
                      .add_owned_const_input(dY.view({N * G, D * HxW}))
                      .add_owned_const_input(X.view({N * G, D * HxW}))
                      .add_owned_const_input(rstd.view({N * G, 1}))
                      .add_owned_const_input(c2.view({N * G, 1}))
                      .add_owned_const_input(c3.view({N * G, 1}))
                      .build();
      gpu_kernel(
          iter, [] GPU_LAMBDA(T dy, T x, T_ACC c1, T_ACC c2, T_ACC c3) -> T {
            return c1 * static_cast<T_ACC>(dy) + c2 * static_cast<T_ACC>(x) +
                c3;
          });
    }
  }
  if (dgamma.defined() || dbeta.defined()) {
    T* dgamma_data = dgamma.defined() ? dgamma.mutable_data_ptr<T>() : nullptr;
    T* dbeta_data = dbeta.defined() ? dbeta.mutable_data_ptr<T>() : nullptr;
    if (N <= 128) {
      // For small batch size, do colwise reduce directly.
      const int64_t B = (C + kCUDANumThreads - 1) / kCUDANumThreads;
      GammaBetaBackwardCUDAKernel1<T><<<B, kCUDANumThreads, 0, cuda_stream>>>(
          N,
          C,
          G,
          mean_data,
          rstd_data,
          ds_data,
          db_data,
          dgamma_data,
          dbeta_data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      const int64_t B = (C + kReduceTileSize - 1) / kReduceTileSize;
      // The algorithm for colwise reduction here is to accumulate each 32 cols
      // to a 32 * 32 tile and write the tile to shared memory. Then do warp
      // reduce for each col in the tile. So here the blockDim must be (32, 16).
      constexpr int kThreadX = kReduceTileSize;
      constexpr int kThreadY = kReduceTileSize / 2;
      GammaBetaBackwardCUDAKernel2<T>
          <<<B, dim3(kThreadX, kThreadY), 0, cuda_stream>>>(
              N,
              C,
              G,
              mean_data,
              rstd_data,
              ds_data,
              db_data,
              dgamma_data,
              dbeta_data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
}

void GroupNormBackwardKernelImpl(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    Tensor& dX,
    Tensor& dgamma,
    Tensor& dbeta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      X.scalar_type(),
      "GroupNormBackwardKernelImpl",
      [&]() {
        GroupNormBackwardKernelImplInternal<scalar_t>(
            dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
      });
}

} // namespace

REGISTER_DISPATCH(GroupNormKernel, &GroupNormKernelImpl);
REGISTER_DISPATCH(GroupNormBackwardKernel, &GroupNormBackwardKernelImpl);

} // namespace at::native