xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/group_norm_kernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/group_norm.h>
3 
4 #include <algorithm>
5 #include <array>
6 #include <numeric>
7 
8 #include <ATen/core/Tensor.h>
9 #include <ATen/Dispatch.h>
10 #include <ATen/cpu/vec/vec.h>
11 #include <ATen/cpu/vec/functional.h>
12 #include <ATen/native/cpu/utils.h>
13 #include <ATen/native/cpu/moments_utils.h>
14 #include <ATen/native/cpu/mixed_data_type.h>
15 #include <ATen/OpMathType.h>
16 #include <c10/util/irange.h>
17 
18 #ifndef AT_PER_OPERATOR_HEADERS
19 #include <ATen/Functions.h>
20 #else
21 #include <ATen/ops/empty.h>
22 #endif
23 
24 namespace at::native {
25 
26 namespace {
27 
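// Forward kernel for contiguous (NCHW-style) inputs. For each (n, g) pair the
// kernel below reduces over the D * HxW elements of the group to obtain mean
// and variance, then normalizes as
//   rstd = 1 / sqrt(max(var, 0) + eps)
//   Y    = (X - mean) * rstd * gamma + beta
// which is refactored into a single fma per element:
//   scale = rstd * gamma[c]
//   bias  = -scale * mean + beta[c]
//   Y     = X * scale + bias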
28 template <typename T, typename PT>
29 void GroupNormKernelImplInternal(
30     const Tensor& X,
31     const Tensor& gamma,
32     const Tensor& beta,
33     int64_t N,
34     int64_t C,
35     int64_t HxW,
36     int64_t group,
37     double eps,
38     Tensor& Y,
39     Tensor& mean,
40     Tensor& rstd) {
41   TORCH_CHECK(X.numel() == N * C * HxW);
42   TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
43   TORCH_CHECK(!beta.defined() || beta.numel() == C);
44   const int64_t G = group;
45   const int64_t D = C / G;
46   const T* X_data = X.const_data_ptr<T>();
47   const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
48   const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
49   T* Y_data = Y.data_ptr<T>();
50   PT* mean_data = mean.data_ptr<PT>();
51   PT* rstd_data = rstd.data_ptr<PT>();
52   const bool gamma_null = (gamma_data == nullptr);
53   const bool beta_null = beta_data == nullptr;
54   const int64_t inner_size = D * HxW;
55 
56   using opmath_t = at::opmath_type<T>;
57 
58   at::parallel_for(0, N * G, 1, [&](int64_t start, int64_t end) {
59     for (const auto i : c10::irange(start, end)) {
60       const T* X_ptr = X_data + i * inner_size;
61       auto [mean_val, rstd_val] = RowwiseMoments(X_ptr, inner_size);
62       rstd_val = opmath_t(1) / std::sqrt(std::max(rstd_val, opmath_t(0)) + eps);
63       if (gamma_null && beta_null) {
64         T* Y_ptr = Y_data + i * inner_size;
65         for (const auto j : c10::irange(inner_size)) {
66           Y_ptr[j] = (X_ptr[j] - mean_val) * rstd_val;
67         }
68       } else {
69         const int64_t g = i % G;
70         for (const auto j : c10::irange(D)) {
71           const int64_t c = g * D + j;
72           const opmath_t scale = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
73           const opmath_t bias = -scale * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
74           X_ptr = X_data + (i * D + j) * HxW;
75           T* Y_ptr = Y_data + (i * D + j) * HxW;
76           for (const auto k : c10::irange(HxW)) {
77             Y_ptr[k] = scale * X_ptr[k] + bias;
78           }
79         }
80       }
81       mean_data[i] = mean_val;
82       rstd_data[i] = rstd_val;
83     }
84   });
85 }
86 
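// ColumnwiseMoments reduces a {HxW, D} slice of a channels-last tensor along
// the HxW rows, accumulating sum(x) and sum(x * x) in vector registers and
// doing the horizontal reduce only once at the end. Despite the names, the
// returned pair holds the raw sums; the caller rescales them into the actual
// mean and rstd (see GroupNormKernelImplChannelsLastInternal below).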
87 template <typename T>
88 typename std::enable_if<std::is_same<T, at::opmath_type<T>>::value,
89   std::tuple<T, T>>::type
90 ColumnwiseMoments(
91     const T* X_data,
92     int64_t HxW,
93     int64_t C,
94     int64_t D) {
95   using Vec = vec::Vectorized<T>;
96   constexpr int64_t K = Vec::size();
97   const int64_t inner_size = D / K * K;
98   Vec acc0_vec{0}, acc1_vec{0};
99   for (const auto m : c10::irange(HxW)) {
100     const T* X_ptr = X_data + m * C;
101     int64_t d = 0;
102     for (; d < inner_size; d += K) {
103       Vec x_vec = Vec::loadu(X_ptr + d);
104       acc0_vec += x_vec;
105       acc1_vec += x_vec * x_vec;
106     }
107     if (D - d > 0) {
108       Vec x_vec = Vec::loadu(X_ptr + d, D - d);
109       acc0_vec += x_vec;
110       acc1_vec += x_vec * x_vec;
111     }
112   }
113   T mean_val = vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; }, acc0_vec);
114   T rstd_val = vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; }, acc1_vec);
115   return std::tuple<T, T>(mean_val, rstd_val);
116 }
117 
118 
119 // std::is_same<T, at::BFloat16> || std::is_same<T, at::Half>
120 template <typename T>
121 typename std::enable_if<!std::is_same<T, at::opmath_type<T>>::value,
122   std::tuple<at::opmath_type<T>, at::opmath_type<T>>>::type
123 ColumnwiseMoments(
124     const T* X_data,
125     int64_t HxW,
126     int64_t C,
127     int64_t D) {
128   using opmath_t = at::opmath_type<T>;
129   using Vec = vec::Vectorized<T>;
130   using fVec = vec::Vectorized<opmath_t>;
131   constexpr int64_t K = Vec::size();
132   const int64_t inner_size = D / K * K;
133   fVec acc0_fvec{0}, acc1_fvec{0}, zero{0};
134   for (const auto m : c10::irange(HxW)) {
135     const T* X_ptr = X_data + m * C;
136     int64_t d = 0;
137     for (; d < inner_size; d += K) {
138       Vec x_bvec = Vec::loadu(X_ptr + d);
139       auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
140       acc0_fvec += x_fvec0 + x_fvec1;
141       acc1_fvec += x_fvec0 * x_fvec0 + x_fvec1 * x_fvec1;
142     }
143     if (D - d > 0) {
144       Vec x_bvec = Vec::loadu(X_ptr + d, D - d);
145       auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
146       if (D - d > fVec::size()) {
147         x_fvec1 = fVec::set(zero, x_fvec1, D - d - fVec::size());
148         acc0_fvec += x_fvec0 + x_fvec1;
149         acc1_fvec += x_fvec0 * x_fvec0 + x_fvec1 * x_fvec1;
150       } else {
151         x_fvec0 = fVec::set(zero, x_fvec0, D - d);
152         acc0_fvec += x_fvec0;
153         acc1_fvec += x_fvec0 * x_fvec0;
154       }
155     }
156   }
157   opmath_t mean_val = vec::vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, acc0_fvec);
158   opmath_t rstd_val = vec::vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, acc1_fvec);
159   return std::tuple<opmath_t, opmath_t>(mean_val, rstd_val);
160 }
161 
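// CalcMeanVar accumulates one channels-last spatial row into per-channel
// buffers: mean_ptr[c] += x and rstd_ptr[c] += x * x. As with
// ColumnwiseMoments, the buffers hold raw sums here; they are turned into the
// actual mean and rstd later, in step-2 of impl-2 below.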
162 template <typename T, typename opmath_t>
163 inline typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
164 CalcMeanVar(
165   const T* X_ptr,
166   opmath_t* mean_ptr,
167   opmath_t* rstd_ptr,
168   int64_t C) {
169   using Vec = vec::Vectorized<T>;
170   vec::map2<T>(
171           [](Vec x, Vec y) { return x + y; },
172           mean_ptr,
173           X_ptr,
174           mean_ptr,
175           C);
176   vec::map2<T>(
177       [](Vec x, Vec y) { return x * x + y; },
178       rstd_ptr,
179       X_ptr,
180       rstd_ptr,
181       C);
182 }
183 
184 // std::is_same<T, at::BFloat16> || std::is_same<T, at::Half>
185 template <typename T, typename opmath_t>
186 inline typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
187 CalcMeanVar(
188   const T* X_ptr,
189   opmath_t* mean_ptr,
190   opmath_t* rstd_ptr,
191   int64_t C) {
192   using fVec = vec::Vectorized<opmath_t>;
193   using Vec = vec::Vectorized<T>;
194   int64_t d = 0;
195   for (; d < C - (C % Vec::size()); d += Vec::size()) {
196     Vec data_bvec = Vec::loadu(X_ptr + d);
197     fVec mean_fvec0 = fVec::loadu(mean_ptr + d);
198     fVec mean_fvec1 = fVec::loadu(mean_ptr + d + fVec::size());
199     fVec rstd_fvec0 = fVec::loadu(rstd_ptr + d);
200     fVec rstd_fvec1 = fVec::loadu(rstd_ptr + d + fVec::size());
201     auto [data_fvec0, data_fvec1] = convert_to_float<T>(data_bvec);
202     mean_fvec0 = data_fvec0 + mean_fvec0;
203     mean_fvec1 = data_fvec1 + mean_fvec1;
204     rstd_fvec0 = data_fvec0 * data_fvec0 + rstd_fvec0;
205     rstd_fvec1 = data_fvec1 * data_fvec1 + rstd_fvec1;
206     mean_fvec0.store(mean_ptr + d);
207     mean_fvec1.store(mean_ptr + d + fVec::size());
208     rstd_fvec0.store(rstd_ptr + d);
209     rstd_fvec1.store(rstd_ptr + d + fVec::size());
210   }
211   if (C - d > 0) {
212     Vec data_bvec = Vec::loadu(X_ptr + d, C - d);
213     fVec mean_fvec0 = fVec::loadu(mean_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
214     fVec mean_fvec1 = fVec::loadu(mean_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
215     fVec rstd_fvec0 = fVec::loadu(rstd_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
216     fVec rstd_fvec1 = fVec::loadu(rstd_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
217     auto [data_fvec0, data_fvec1] = convert_to_float<T>(data_bvec);
218     mean_fvec0 = data_fvec0 + mean_fvec0;
219     mean_fvec1 = data_fvec1 + mean_fvec1;
220     rstd_fvec0 = data_fvec0 * data_fvec0 + rstd_fvec0;
221     rstd_fvec1 = data_fvec1 * data_fvec1 + rstd_fvec1;
222     mean_fvec0.store(mean_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
223     mean_fvec1.store(mean_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
224     rstd_fvec0.store(rstd_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
225     rstd_fvec1.store(rstd_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
226   }
227 }
228 
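// ApplyScaleBias computes Y = X * scale + bias, vectorized along the channel
// dimension. The reduced-precision overload widens BFloat16/Half lanes to
// float, does the fma in float, and narrows the result back when storing.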
229 template <typename T, typename opmath_t>
230 inline typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
231 ApplyScaleBias(
232   T* Y_ptr,
233   const T* X_ptr,
234   const opmath_t* scale_ptr,
235   const opmath_t* bias_ptr,
236   int64_t C) {
237   using Vec = vec::Vectorized<T>;
238   vec::map3<T>(
239     [](Vec x, Vec scale, Vec bias) { return x * scale + bias; },
240     Y_ptr,
241     X_ptr,
242     scale_ptr,
243     bias_ptr,
244     C);
245 }
246 
247 // std::is_same<T, at::BFloat16> || std::is_same<T, at::Half>
248 template <typename T, typename opmath_t>
249 inline typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
250 ApplyScaleBias(
251   T* Y_ptr,
252   const T* X_ptr,
253   const opmath_t* scale_ptr,
254   const opmath_t* bias_ptr,
255   int64_t C) {
256   using fVec = vec::Vectorized<opmath_t>;
257   using Vec = vec::Vectorized<T>;
258   int64_t d = 0;
259   for (; d < C - (C % Vec::size()); d += Vec::size()) {
260     Vec data_bvec = Vec::loadu(X_ptr + d);
261     fVec scale_fvec0 = fVec::loadu(scale_ptr + d);
262     fVec scale_fvec1 = fVec::loadu(scale_ptr + d + fVec::size());
263     fVec bias_fvec0 = fVec::loadu(bias_ptr + d);
264     fVec bias_fvec1 = fVec::loadu(bias_ptr + d + fVec::size());
265     auto [data_fvec0, data_fvec1] = convert_to_float<T>(data_bvec);
266     fVec out0 = data_fvec0 * scale_fvec0 + bias_fvec0;
267     fVec out1 = data_fvec1 * scale_fvec1 + bias_fvec1;
268     convert_from_float<T>(out0, out1).store(Y_ptr + d);
269   }
270   if (C - d > 0) {
271     Vec data_bvec = Vec::loadu(X_ptr + d, C - d);
272     fVec scale_fvec0 = fVec::loadu(scale_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
273     fVec scale_fvec1 = fVec::loadu(scale_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
274     fVec bias_fvec0 = fVec::loadu(bias_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
275     fVec bias_fvec1 = fVec::loadu(bias_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
276     auto [data_fvec0, data_fvec1] = convert_to_float<T>(data_bvec);
277     fVec out0 = data_fvec0 * scale_fvec0 + bias_fvec0;
278     fVec out1 = data_fvec1 * scale_fvec1 + bias_fvec1;
279     convert_from_float<T>(out0, out1).store(Y_ptr + d, C - d);
280   }
281 }
282 
283 template <typename T, typename PT>
284 void GroupNormKernelImplChannelsLastInternal(
285     const Tensor& X,
286     const Tensor& gamma,
287     const Tensor& beta,
288     int64_t N,
289     int64_t C,
290     int64_t HxW,
291     int64_t group,
292     double eps,
293     Tensor& Y,
294     Tensor& mean,
295     Tensor& rstd) {
296   TORCH_CHECK(X.numel() == N * C * HxW);
297   TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
298   TORCH_CHECK(!beta.defined() || beta.numel() == C);
299   const int64_t G = group;
300   const int64_t D = C / G;
301   const T* X_data = X.const_data_ptr<T>();
302   const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
303   const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
304   T* Y_data = Y.data_ptr<T>();
305   PT* mean_data = mean.data_ptr<PT>();
306   PT* rstd_data = rstd.data_ptr<PT>();
307 
308   using opmath_t = at::opmath_type<T>;
309 
310   const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
311   const bool gamma_null = (gamma_data == nullptr);
312   const bool beta_null = beta_data == nullptr;
313 
314   // NB: About the algorithm chosen:
315   //
316   // On channels last, GroupNorm has an input shape of {N, H, W, GD}.
317   // Mean and rstd are collected for each n and g, which involves a reduction
318   // on non-adjacent dimensions. We can parallelize with either of the following 2 impls:
319   //
320   // impl-1: parallel on N * G. Only one omp session is needed, but memory access
321   //   per thread is non-contiguous.
322   //
323   // impl-2: parallel on N * HxW. Memory access per thread is contiguous,
324   //   but it requires an extra temp buffer of size {T, N, 2C}.
325   //
326   // Generally impl-2 has better performance when HxW is large enough, so that
327   //   the data per thread {NHWC / T} is much larger than the temp buffer per thread {2NC}.
328   //
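  // As a rough illustration (numbers are made up, not benchmarked): for an
  // input of shape {1, 32, 16, 16}, HxW = 256 < 1024, so impl-1 is used;
  // for {1, 32, 64, 64}, HxW = 4096 >= 1024, so impl-2 is used.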
329   constexpr int64_t feature_map_threshold = 1024;
330   if (HxW < feature_map_threshold) {
331     // impl-1: parallel on N * G.
332     //
333     // for each plane of HxW, scale and bias are calculated only once
334     Tensor buffer = at::empty({N * G, 2 * D}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
335     opmath_t* buffer_data = buffer.data_ptr<opmath_t>();
336 
337     at::parallel_for(0, N * G, 1, [&](int64_t begin, int64_t end) {
338       int64_t n{0}, g{0};
339       data_index_init(begin, n, N, g, G);
340       for (const auto i : c10::irange(begin, end)) {
341         // step-1: for each n and g, collect sum of x and x2
342         //
343         // Note that using vec::map_reduce_all here would be simpler to write,
344         // but it is slower since the horizontal reduce from vec to scalar is slow.
345         // So it is better to reduce with a vec across the whole HxW plane,
346         // and do a horizontal add just once for each {n, g}.
347         //
348         auto [mean_val, rstd_val] = ColumnwiseMoments(
349                 X_data + n * HxW * C + g * D,
350                 HxW,
351                 C,
352                 D);
353 
354         mean_val *= s;
355         rstd_val = std::max(rstd_val * s - mean_val * mean_val, opmath_t(0));
356         rstd_val = opmath_t(1) / std::sqrt(rstd_val + eps);
357         mean_data[i] = mean_val;
358         rstd_data[i] = rstd_val;
359 
360         // step-2: calculate scale and bias
361         opmath_t* scale_ptr = buffer_data + i * 2 * D;
362         opmath_t* bias_ptr = scale_ptr + D;
363         for (const auto d : c10::irange(D)) {
364           const int64_t c = g * D + d;
365           scale_ptr[d] = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
366           bias_ptr[d] = -scale_ptr[d] * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
367         }
368 
369         // step-3: apply scale and bias
370         for (const auto m : c10::irange(HxW)) {
371           const T* X_ptr = X_data + n * HxW * C + m * C + g * D;
372           T* Y_ptr = Y_data + n * HxW * C + m * C + g * D;
373           ApplyScaleBias<T, opmath_t>(Y_ptr, X_ptr, scale_ptr, bias_ptr, D);
374         }
375 
376         data_index_step(n, N, g, G);
377       }
378     });
379   } else {
380     // impl-2: parallel on N * HxW.
381     //
382     // temp buffer holding x and x2
383     int num_threads = at::get_num_threads();
384     Tensor buffer = at::empty({num_threads, N, 2 * C},
385       X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value)).zero_();
386     opmath_t* buffer_data = buffer.data_ptr<opmath_t>();
387     Tensor tmp_buffer = at::empty({N, 2 * G},
388       X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
389     opmath_t* tmp_buffer_data = tmp_buffer.data_ptr<opmath_t>();
390     // step-1: accumulate on dimension of C
391     //
392     // In order to improve multi-core performance when N=1,
393     // we parallelize on all the outer dimensions of N and HxW,
394     // leaving the innermost dimension C for vectorization.
395     //
396     // Note that parallelizing on {N, HxW, G} is not feasible for some common configs,
397     // e.g. say the input shape is {1, 32, h, w} and G = 8:
398     //   this gives D = 4, which is too small to fill a full SIMD register.
399     //
400     // To avoid thread conflicts, we make use of a temp buffer of {T, N, 2C}:
401     //   first, reduce from {N, HxW, C} to {T, N, 2C}.
402     //
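    // Per-thread buffer layout used below (thread t, sample n):
    //   buffer_data + t * N * 2C + n * 2C      -> running sum of x   over the C channels
    //   buffer_data + t * N * 2C + n * 2C + C  -> running sum of x*x over the C channels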
403     at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
404       int tid = at::get_thread_num();
405       opmath_t* buffer_ptr = buffer_data + tid * N * 2 * C;
406 
407       int64_t n{0}, m{0};
408       data_index_init(begin, n, N, m, HxW);
409       for (const auto i : c10::irange(begin, end)) {
410         opmath_t* mean_ptr = buffer_ptr + n * 2 * C;
411         opmath_t* rstd_ptr = mean_ptr + C;
412         const T* X_ptr = X_data + i * C;
413         CalcMeanVar<T, opmath_t>(X_ptr, mean_ptr, rstd_ptr, C);
414         data_index_step(n, N, m, HxW);
415       }
416     });
417 
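    // With s = 1 / (D * HxW), the reduction below computes, for each (n, g):
    //   mean = s * sum(x)
    //   var  = max(s * sum(x * x) - mean * mean, 0)
    //   rstd = 1 / sqrt(var + eps)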
418     // step-2: compute mean and rstd
419     for (const auto n : c10::irange(N)) {
420       for (const auto g : c10::irange(G)) {
421         opmath_t mean_val{0}, rstd_val{0};
422         for (const auto d : c10::irange(D)) {
423           for (const auto t : c10::irange(num_threads)) {
424             opmath_t* buffer_ptr = buffer_data + t * N * 2 * C + n * 2 * C;
425             mean_val += buffer_ptr[g * D + d];
426             rstd_val += buffer_ptr[g * D + d + C];
427           }
428         }
429         mean_val *= s;
430         rstd_val = std::max(rstd_val * s - mean_val * mean_val, opmath_t(0));
431         rstd_val = opmath_t(1) / std::sqrt(rstd_val + eps);
432         tmp_buffer_data[n * 2 * G + 2 * g] = mean_val;
433         tmp_buffer_data[n * 2 * G + 2 * g + 1] = rstd_val;
434       }
435     }
436 
437     // step-3: compute scale and bias
438     //
439     // mean/rstd have a shape of {N, G}; gamma/beta have a shape of {G, D}.
440     // And scale/bias have a shape of {N, C} so that we can directly vectorize on
441     // the dimension of C in the final step.
442     //
443     // We could fuse steps 3 and 4 into a single session, but this way is better:
444     //   a. D might be too small for vectorization;
445     //   b. it avoids duplicate calculation of scale/bias, since each HxW plane shares the same scale/bias.
446     //
447     for (const auto n : c10::irange(N)) {
448       for (const auto g : c10::irange(G)) {
449         opmath_t* scale_ptr = buffer_data + n * 2 * C;
450         opmath_t* bias_ptr = scale_ptr + C;
451         opmath_t mean_val = tmp_buffer_data[n * 2 * G + 2 * g];
452         opmath_t rstd_val = tmp_buffer_data[n * 2 * G + 2 * g + 1];
453         mean_data[n * G + g] = mean_val;
454         rstd_data[n * G + g] = rstd_val;
455 
456         for (const auto d : c10::irange(D)) {
457           const int64_t c = g * D + d;
458           scale_ptr[c] = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
459           bias_ptr[c] = -scale_ptr[c] * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
460         }
461       }
462     }
463 
464     // step-4: apply scale and bias
465     //
466     // Parallelize on all the outer dimensions of N and HxW
467     // and vectorize on C.
468     //
469     at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
470       int64_t n{0}, m{0};
471       data_index_init(begin, n, N, m, HxW);
472       for (const auto i : c10::irange(begin, end)) {
473         const T* X_ptr = X_data + i * C;
474         T* Y_ptr = Y_data + i * C;
475         opmath_t* scale_ptr = buffer_data + n * 2 * C;
476         opmath_t* bias_ptr = scale_ptr + C;
477         ApplyScaleBias<T, opmath_t>(Y_ptr, X_ptr, scale_ptr, bias_ptr, C);
478         data_index_step(n, N, m, HxW);
479       }
480     });
481   }
482 }
483 
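// Dispatcher: picks the contiguous or channels-last kernel based on the
// suggested memory format. For mixed-type inputs (e.g. BFloat16/Half
// activations with float gamma/beta), the parameter type is promoted to
// at::opmath_type<scalar_t>, i.e. float; otherwise it matches the activation
// type.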
484 void GroupNormKernelImpl(
485     const Tensor& X,
486     const Tensor& gamma,
487     const Tensor& beta,
488     int64_t N,
489     int64_t C,
490     int64_t HxW,
491     int64_t group,
492     double eps,
493     Tensor& Y,
494     Tensor& mean,
495     Tensor& rstd) {
496   const bool mixed_type = is_mixed_type(X, gamma, beta);
497   switch (X.suggest_memory_format()) {
498     case at::MemoryFormat::Contiguous: {
499       AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, X.scalar_type(), "GroupNormKernelImpl", [&]() {
500         using param_t = at::opmath_type<scalar_t>;
501         if (mixed_type) {
502           GroupNormKernelImplInternal<scalar_t, param_t>(
503               X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
504         } else {
505           GroupNormKernelImplInternal<scalar_t, scalar_t>(
506               X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
507         }
508       });
509       break;
510     }
511     case at::MemoryFormat::ChannelsLast:
512     case at::MemoryFormat::ChannelsLast3d: {
513       AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, X.scalar_type(), "GroupNormKernelImpl", [&]() {
514         using param_t = at::opmath_type<scalar_t>;
515         if (mixed_type) {
516           GroupNormKernelImplChannelsLastInternal<scalar_t, param_t>(
517               X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
518         } else {
519           GroupNormKernelImplChannelsLastInternal<scalar_t, scalar_t>(
520               X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
521         }
522       });
523       break;
524     }
525     default:
526       TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, ChannelsLast3d, Contiguous");
527   }
528 }
529 
530 
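// Backward helpers. ComputeInternalGradients fills, for every (n, c) plane of
// HxW elements,
//   ds[n][c] = sum(dY * X)
//   db[n][c] = sum(dY)
// These two per-channel sums are all the dX / dgamma / dbeta formulas below need.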
531 template <typename T, typename opmath_t>
532 typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
533 ComputeInternalGradients(
534     int64_t N,
535     int64_t C,
536     int64_t HxW,
537     const T* dY,
538     const T* X,
539     opmath_t* ds,
540     opmath_t* db) {
541   using Vec = at::vec::Vectorized<opmath_t>;
542   at::parallel_for(0, N * C, 1, [=](int64_t start, int64_t end) {
543     for (const auto i : c10::irange(start, end)) {
544       const T* dY_ptr = dY + i * HxW;
545       const T* X_ptr = X + i * HxW;
546       ds[i] = at::vec::map2_reduce_all<T>(
547           [](Vec x, Vec y) { return x * y; },
548           [](Vec x, Vec y) { return x + y; },
549           dY_ptr,
550           X_ptr,
551           HxW);
552       db[i] = at::vec::reduce_all<T>(
553           [](Vec& x, Vec& y) { return x + y; }, dY_ptr, HxW);
554     }
555   });
556 }
557 
558 template <typename T, typename opmath_t>
559 typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
560 ComputeInternalGradients(
561     int64_t N,
562     int64_t C,
563     int64_t HxW,
564     const T* dY,
565     const T* X,
566     opmath_t* ds,
567     opmath_t* db) {
568   using Vec = vec::Vectorized<T>;
569   using fVec = vec::Vectorized<opmath_t>;
570   at::parallel_for(0, N * C, 1, [=](int64_t start, int64_t end) {
571     constexpr int64_t K = Vec::size();
572     const int64_t inner_size = HxW / K * K;
573     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
574     std::array<opmath_t, K / 2> ds_arr;
575     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
576     std::array<opmath_t, K / 2> db_arr;
577     for (const auto i : c10::irange(start, end)) {
578       const T* dY_ptr = dY + i * HxW;
579       const T* X_ptr = X + i * HxW;
580       fVec ds_vec(0);
581       fVec db_vec(0);
582       for (int64_t j = 0; j < inner_size; j += K) {
583         const Vec dy_bvec = Vec::loadu(dY_ptr + j);
584         const Vec x_bvec = Vec::loadu(X_ptr + j);
585         auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
586         auto [dy_fvec0, dy_fvec1] = convert_to_float<T>(dy_bvec);
587         ds_vec = ds_vec + dy_fvec0 * x_fvec0;
588         ds_vec = ds_vec + dy_fvec1 * x_fvec1;
589         db_vec = db_vec + dy_fvec0 + dy_fvec1;
590       }
591       ds_vec.store(ds_arr.data());
592       db_vec.store(db_arr.data());
593       opmath_t ds_val = std::accumulate(ds_arr.cbegin(), ds_arr.cend(), opmath_t(0));
594       opmath_t db_val = std::accumulate(db_arr.cbegin(), db_arr.cend(), opmath_t(0));
595       for (const auto j : c10::irange(inner_size, HxW)) {
596         ds_val += opmath_t(dY_ptr[j]) * opmath_t(X_ptr[j]);
597         db_val += opmath_t(dY_ptr[j]);
598       }
599       ds[i] = ds_val;
600       db[i] = db_val;
601     }
602   });
603 }
604 
605 template <typename PT, typename opmath_t>
606 inline typename std::enable_if<std::is_same<PT, opmath_t>::value, void>::type
607 CalcDsDb(
608     const opmath_t* ds_ptr,
609     const opmath_t* db_ptr,
610     const PT* gamma_ptr,
611     const int64_t d,
612     const int64_t K,
613     void* ds_arr,
614     void* db_arr) {
615     vec::Vectorized<opmath_t> ds_vec(0);
616     vec::Vectorized<opmath_t> db_vec(0);
617     for (int64_t j = 0; j < d; j += K) {
618       const vec::Vectorized<PT> gamma_vec = (gamma_ptr == nullptr)
619           ? vec::Vectorized<PT>(1)
620           : vec::Vectorized<PT>::loadu(gamma_ptr + j);
621       ds_vec = ds_vec + vec::Vectorized<PT>::loadu(ds_ptr + j) * gamma_vec;
622       db_vec = db_vec + vec::Vectorized<PT>::loadu(db_ptr + j) * gamma_vec;
623     }
624     ds_vec.store(ds_arr);
625     db_vec.store(db_arr);
626 }
627 
628 template <typename PT, typename opmath_t>
629 inline typename std::enable_if<!std::is_same<PT, opmath_t>::value, void>::type
630 CalcDsDb(
631     const opmath_t* ds_ptr,
632     const opmath_t* db_ptr,
633     const PT* gamma_ptr,
634     const int64_t d,
635     const int64_t K,
636     void* ds_arr,
637     void* db_arr) {
638   using fVec = at::vec::Vectorized<opmath_t>;
639   using Vec = at::vec::Vectorized<PT>;
640   fVec ds_acc(0);
641   fVec db_acc(0);
642   for (int64_t j = 0; j < d; j += K) {
643     const Vec gamma_vec = (gamma_ptr == nullptr) ? Vec(1) : Vec::loadu(gamma_ptr + j);
644     auto [gamma_vec0, gamma_vec1] = convert_to_float<PT>(gamma_vec);
645     ds_acc += fVec::loadu(ds_ptr + j) * gamma_vec0;
646     ds_acc += fVec::loadu(ds_ptr + j + fVec::size()) * gamma_vec1;
647     db_acc += fVec::loadu(db_ptr + j) * gamma_vec0;
648     db_acc += fVec::loadu(db_ptr + j + fVec::size()) * gamma_vec1;
649   }
650   ds_acc.store(ds_arr);
651   db_acc.store(db_arr);
652 }
653 
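// GroupNormInputBackward computes dX = c1 * dY + c2 * X + c3, where for each
// (n, g) with s = 1 / (D * HxW):
//   u  = sum_j ds[n][g*D + j] * gamma[g*D + j]
//   v  = sum_j db[n][g*D + j] * gamma[g*D + j]
//   c1 = rstd * gamma[c]
//   c2 = (v * mean - u) * rstd^3 * s
//   c3 = -c2 * mean - v * rstd * s
// CalcDsDb above produces the vectorized portion of u and v.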
654 template <typename T, typename PT, typename opmath_t>
655 void GroupNormInputBackward(
656     int64_t N,
657     int64_t C,
658     int64_t HxW,
659     int64_t group,
660     const T* dY,
661     const T* X,
662     const PT* mean,
663     const PT* rstd,
664     const PT* gamma,
665     const opmath_t* ds,
666     const opmath_t* db,
667     T* dX) {
668   const int64_t G = group;
669   const int64_t D = C / G;
670   const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
671   const bool gamma_null = (gamma == nullptr);
672   at::parallel_for(0, N * G, 1, [=](int64_t start, int64_t end) {
673     constexpr int64_t K = vec::Vectorized<PT>::size();
674     const int64_t d = D / K * K;
675     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
676     std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> ds_arr;
677     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
678     std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> db_arr;
679     for (const auto i : c10::irange(start, end)) {
680       const int64_t g = i % G;
681       const opmath_t* ds_ptr = ds + i * D;
682       const opmath_t* db_ptr = db + i * D;
683       const PT* gamma_ptr = gamma_null ? nullptr : (gamma + g * D);
684       CalcDsDb(ds_ptr, db_ptr, gamma_ptr, d, K, ds_arr.data(), db_arr.data());
685       opmath_t ds_val = std::accumulate(ds_arr.cbegin(), ds_arr.cend(), opmath_t(0));
686       opmath_t db_val = std::accumulate(db_arr.cbegin(), db_arr.cend(), opmath_t(0));
687       for (const auto j : c10::irange(d, D)) {
688         const opmath_t gamma_v = gamma_null ? opmath_t(1) : opmath_t(gamma[g * D + j]);
689         ds_val += ds_ptr[j] * gamma_v;
690         db_val += db_ptr[j] * gamma_v;
691       }
692       const opmath_t c2 =
693           (db_val * opmath_t(mean[i]) - ds_val) * opmath_t(rstd[i]) * opmath_t(rstd[i]) * opmath_t(rstd[i]) * s;
694       const opmath_t c3 = -c2 * opmath_t(mean[i]) - db_val * opmath_t(rstd[i]) * s;
695 
696       for (const auto j : c10::irange(D)) {
697         const int64_t c = g * D + j;
698         const T* dY_ptr = dY + (i * D + j) * HxW;
699         const T* X_ptr = X + (i * D + j) * HxW;
700         T* dX_ptr = dX + (i * D + j) * HxW;
701         const opmath_t c1 = opmath_t(rstd[i]) * (gamma_null ? opmath_t(1) : opmath_t(gamma[c]));
702         for (const auto k : c10::irange(HxW)) {
703           dX_ptr[k] = c1 * opmath_t(dY_ptr[k]) + c2 * opmath_t(X_ptr[k]) + c3;
704         }
705       }
706     }
707   });
708 }
709 
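// GammaBackward: dgamma[g*D + d] = sum_n (ds[n][c] - db[n][c] * mean[n][g]) * rstd[n][g],
// with c = g*D + d. The reduced-precision overload accumulates in float and
// converts only when storing dgamma.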
710 template <typename PT, typename opmath_t>
711 typename std::enable_if<std::is_same<PT, opmath_t>::value, void>::type
712 GammaBackward(
713     int64_t N,
714     int64_t C,
715     int64_t group,
716     const PT* mean,
717     const PT* rstd,
718     const opmath_t* ds,
719     const opmath_t* db,
720     PT* dgamma) {
721   const int64_t G = group;
722   const int64_t D = C / G;
723   constexpr int64_t K = at::vec::Vectorized<PT>::size();
724   using Vec = at::vec::Vectorized<PT>;
725   const int64_t inner_size = D / K * K;
726   for (const auto g : c10::irange(G)) {
727     int64_t i = 0;
728     for (; i < inner_size; i += K) {
729       Vec acc_vec{0};
730       for (const auto n : c10::irange(N)) {
731         const PT* ds_ptr = ds + n * C + g * D + i;
732         const PT* db_ptr = db + n * C + g * D + i;
733         auto ds_vec = Vec::loadu(ds_ptr);
734         auto db_vec = Vec::loadu(db_ptr);
735         auto mean_vec = Vec(mean[n * G + g]);
736         auto rstd_vec = Vec(rstd[n * G + g]);
737         acc_vec += (ds_vec - db_vec * mean_vec) * rstd_vec;
738       }
739       acc_vec.store(dgamma + g * D + i);
740     }
741     if (D - i > 0) {
742       Vec acc_vec{0};
743       for (const auto n : c10::irange(N)) {
744         const PT* ds_ptr = ds + n * C + g * D + i;
745         const PT* db_ptr = db + n * C + g * D + i;
746         auto ds_vec = Vec::loadu(ds_ptr, D - i);
747         auto db_vec = Vec::loadu(db_ptr, D - i);
748         auto mean_vec = Vec(mean[n * G + g]);
749         auto rstd_vec = Vec(rstd[n * G + g]);
750         acc_vec += (ds_vec - db_vec * mean_vec) * rstd_vec;
751       }
752       acc_vec.store(dgamma + g * D + i, D - i);
753     }
754   }
755 }
756 
757 template <typename PT, typename opmath_t>
758 typename std::enable_if<!std::is_same<PT, opmath_t>::value, void>::type
759 GammaBackward(
760     int64_t N,
761     int64_t C,
762     int64_t group,
763     const PT* mean,
764     const PT* rstd,
765     const opmath_t* ds,
766     const opmath_t* db,
767     PT* dgamma) {
768   const int64_t G = group;
769   const int64_t D = C / G;
770   using Vec = at::vec::Vectorized<PT>;
771   using fVec = at::vec::Vectorized<opmath_t>;
772   constexpr int64_t K = Vec::size();
773   const int64_t inner_size = D / K * K;
774   for (const auto g : c10::irange(G)) {
775     int64_t i = 0;
776     for (; i < inner_size; i += K) {
777       fVec acc0_vec{0}, acc1_vec{0};
778       for (const auto n : c10::irange(N)) {
779         const opmath_t* ds_ptr = ds + n * C + g * D + i;
780         const opmath_t* db_ptr = db + n * C + g * D + i;
781         fVec ds_vec0, ds_vec1, db_vec0, db_vec1;
782         ds_vec0 = fVec::loadu(ds_ptr);
783         ds_vec1 = fVec::loadu(ds_ptr + fVec::size());
784         db_vec0 = fVec::loadu(db_ptr);
785         db_vec1 = fVec::loadu(db_ptr + fVec::size());
786         fVec mean_vec = fVec(opmath_t(mean[n * G + g]));
787         fVec rstd_vec = fVec(opmath_t(rstd[n * G + g]));
788         acc0_vec += (ds_vec0 - db_vec0 * mean_vec) * rstd_vec;
789         acc1_vec += (ds_vec1 - db_vec1 * mean_vec) * rstd_vec;
790       }
791       convert_from_float<PT>(acc0_vec, acc1_vec).store(dgamma + g * D + i);
792     }
793     if (D - i > 0) {
794       fVec acc0_vec{0}, acc1_vec{0};
795       for (const auto n : c10::irange(N)) {
796         const opmath_t* ds_ptr = ds + n * C + g * D + i;
797         const opmath_t* db_ptr = db + n * C + g * D + i;
798         fVec ds_vec0, ds_vec1, db_vec0, db_vec1;
799         ds_vec0 = fVec::loadu(
800             ds_ptr, (D - i) > fVec::size() ? fVec::size() : (D - i));
801         ds_vec1 = fVec::loadu(
802             ds_ptr + fVec::size(),
803             (D - i) > fVec::size() ? (D - i - fVec::size()) : 0);
804         db_vec0 = fVec::loadu(
805             db_ptr, (D - i) > fVec::size() ? fVec::size() : (D - i));
806         db_vec1 = fVec::loadu(
807             db_ptr + fVec::size(),
808             (D - i) > fVec::size() ? (D - i - fVec::size()) : 0);
809         fVec mean_vec = fVec(opmath_t(mean[n * G + g]));
810         fVec rstd_vec = fVec(opmath_t(rstd[n * G + g]));
811         acc0_vec += (ds_vec0 - db_vec0 * mean_vec) * rstd_vec;
812         acc1_vec += (ds_vec1 - db_vec1 * mean_vec) * rstd_vec;
813       }
814       convert_from_float<PT>(acc0_vec, acc1_vec).store(dgamma + g * D + i, D - i);
815     }
816   }
817 }
818 
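// BetaBackward: dbeta[c] = sum_n db[n][c], vectorized along C.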
819 template <typename PT, typename opmath_t>
820 typename std::enable_if<std::is_same<PT, opmath_t>::value, void>::type
821 BetaBackward(int64_t N, int64_t C, const opmath_t* db, PT* dbeta) {
822   using Vec = at::vec::Vectorized<PT>;
823   constexpr int64_t K = Vec::size();
824   Vec acc_vec{0}, zero{0};
825   const int64_t inner_size = C / K * K;
826   int64_t i = 0;
827   for (; i < inner_size; i += K) {
828     for (const auto n : c10::irange(N)) {
829       acc_vec += Vec::loadu(db + n * C + i);
830     }
831     acc_vec.store(dbeta + i);
832     acc_vec = Vec::set(acc_vec, zero);
833   }
834   if (C - i > 0) {
835     for (const auto n : c10::irange(N)) {
836       acc_vec += Vec::loadu(db + n * C + i, C - i);
837     }
838     acc_vec.store(dbeta + i, C - i);
839     acc_vec = Vec::set(acc_vec, zero, C - i);
840   }
841 }
842 
843 template <typename PT, typename opmath_t>
844 typename std::enable_if<!std::is_same<PT, opmath_t>::value, void>::type
845 BetaBackward(int64_t N, int64_t C, const opmath_t* db, PT* dbeta) {
846   using Vec = at::vec::Vectorized<PT>;
847   using fVec = at::vec::Vectorized<opmath_t>;
848   constexpr int64_t K = Vec::size();
849   fVec acc0_vec{0}, acc1_vec{0}, zero{0};
850   const int64_t inner_size = C / K * K;
851   int64_t i = 0;
852   for (; i < inner_size; i += K) {
853     for (const auto n : c10::irange(N)) {
854       fVec db_vec0, db_vec1;
855       db_vec0 = fVec::loadu(db + n * C + i);
856       db_vec1 = fVec::loadu(db + n * C + i + fVec::size());
857       acc0_vec += db_vec0;
858       acc1_vec += db_vec1;
859     }
860     convert_from_float<PT>(acc0_vec, acc1_vec).store(dbeta + i);
861     acc0_vec = fVec::set(acc0_vec, zero);
862     acc1_vec = fVec::set(acc1_vec, zero);
863   }
864   if (C - i > 0) {
865     for (const auto n : c10::irange(N)) {
866       fVec db_vec0, db_vec1;
867       db_vec0 = fVec::loadu(
868           db + n * C + i, (C - i) > fVec::size() ? fVec::size() : (C - i));
869       db_vec1 = fVec::loadu(
870           db + n * C + i + fVec::size(),
871           (C - i) > fVec::size() ? (C - i - fVec::size()) : 0);
872       acc0_vec += db_vec0;
873       acc1_vec += db_vec1;
874     }
875     convert_from_float<PT>(acc0_vec, acc1_vec).store(dbeta + i, C - i);
876     acc0_vec = fVec::set(acc0_vec, zero, C - i);
877     acc1_vec = fVec::set(acc1_vec, zero, C - i);
878   }
879 }
880 
881 template <typename T, typename PT>
882 void GroupNormBackwardKernelImplInternal(
883     const Tensor& dY,
884     const Tensor& X,
885     const Tensor& mean,
886     const Tensor& rstd,
887     const Tensor& gamma,
888     int64_t N,
889     int64_t C,
890     int64_t HxW,
891     int64_t group,
892     Tensor& dX,
893     Tensor& dgamma,
894     Tensor& dbeta) {
895   TORCH_CHECK(dY.numel() == N * C * HxW);
896   TORCH_CHECK(X.numel() == N * C * HxW);
897   TORCH_CHECK(mean.numel() == N * group);
898   TORCH_CHECK(rstd.numel() == N * group);
899   TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
900   const T* dY_data = dY.const_data_ptr<T>();
901   const T* X_data = X.const_data_ptr<T>();
902   const PT* mean_data = mean.const_data_ptr<PT>();
903   const PT* rstd_data = rstd.const_data_ptr<PT>();
904   const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
905   T* dX_data = dX.defined() ? dX.data_ptr<T>() : nullptr;
906   PT* dgamma_data = dgamma.defined() ? dgamma.data_ptr<PT>() : nullptr;
907   PT* dbeta_data = dbeta.defined() ? dbeta.data_ptr<PT>() : nullptr;
908   using opmath_t = at::opmath_type<T>;
909   Tensor ds = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
910   Tensor db = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
911   opmath_t* ds_data = ds.data_ptr<opmath_t>();
912   opmath_t* db_data = db.data_ptr<opmath_t>();
913   ComputeInternalGradients<T, opmath_t>(N, C, HxW, dY_data, X_data, ds_data, db_data);
914 
915   if (dX_data != nullptr) {
916     GroupNormInputBackward<T, PT, opmath_t>(
917         N,
918         C,
919         HxW,
920         group,
921         dY_data,
922         X_data,
923         mean_data,
924         rstd_data,
925         gamma_data,
926         ds_data,
927         db_data,
928         dX_data);
929   }
930   if (dgamma_data != nullptr) {
931     GammaBackward(
932         N, C, group, mean_data, rstd_data, ds_data, db_data, dgamma_data);
933   }
934   if (dbeta_data != nullptr) {
935     BetaBackward(N, C, db_data, dbeta_data);
936   }
937 }
938 
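// Channels-last backward helpers. DsDbRowwiseMomentsChannelsLast accumulates a
// single spatial row into per-channel buffers: ds_ptr[c] += X * dY and
// db_ptr[c] += dY, mirroring what ComputeInternalGradients does for the
// contiguous path.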
939 template <typename T, typename opmath_t>
940 inline typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
941 DsDbRowwiseMomentsChannelsLast(
942   const T* dY_ptr,
943   const T* X_ptr,
944   opmath_t* ds_ptr,
945   opmath_t* db_ptr,
946   int64_t C) {
947   using Vec = vec::Vectorized<T>;
948   constexpr int64_t K = vec::Vectorized<T>::size();
949   const int64_t inner_size = C / K * K;
950   int64_t d = 0;
951   for (; d < inner_size; d += K) {
952     Vec ds_dev = Vec::loadu(ds_ptr + d);
953     Vec db_vec = Vec::loadu(db_ptr + d);
954     Vec x_vec = Vec::loadu(X_ptr + d);
955     Vec dy_vec = Vec::loadu(dY_ptr + d);
956 
957     ds_dev += x_vec * dy_vec;
958     db_vec += dy_vec;
959     ds_dev.store(ds_ptr + d);
960     db_vec.store(db_ptr + d);
961   }
962   if (C - d > 0) {
963     Vec ds_dev = Vec::loadu(ds_ptr + d, C - d);
964     Vec db_vec = Vec::loadu(db_ptr + d, C - d);
965     Vec x_vec = Vec::loadu(X_ptr + d, C - d);
966     Vec dy_vec = Vec::loadu(dY_ptr + d, C - d);
967     ds_dev += x_vec * dy_vec;
968     db_vec += dy_vec;
969     ds_dev.store(ds_ptr + d, C - d);
970     db_vec.store(db_ptr + d, C - d);
971   }
972 }
973 
974 template <typename T, typename opmath_t>
975 inline typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
976 DsDbRowwiseMomentsChannelsLast(
977   const T* dY_ptr,
978   const T* X_ptr,
979   opmath_t* ds_ptr,
980   opmath_t* db_ptr,
981   int64_t C) {
982   using fVec = vec::Vectorized<opmath_t>;
983   using Vec = vec::Vectorized<T>;
984   int64_t d = 0;
985   for (; d < C - (C % Vec::size()); d += Vec::size()) {
986     fVec ds_dev0 = fVec::loadu(ds_ptr + d);
987     fVec ds_dev1 = fVec::loadu(ds_ptr + d + fVec::size());
988     fVec db_vec0 = fVec::loadu(db_ptr + d);
989     fVec db_vec1 = fVec::loadu(db_ptr + d + fVec::size());
990     Vec x_vec = Vec::loadu(X_ptr + d);
991     Vec dy_vec = Vec::loadu(dY_ptr + d);
992     auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
993     auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
994     ds_dev0 += x_vec0 * dy_vec0;
995     ds_dev1 += x_vec1 * dy_vec1;
996     db_vec0 += dy_vec0;
997     db_vec1 += dy_vec1;
998 
999     ds_dev0.store(ds_ptr + d);
1000     ds_dev1.store(ds_ptr + d + fVec::size());
1001     db_vec0.store(db_ptr + d);
1002     db_vec1.store(db_ptr + d + fVec::size());
1003 
1004   }
1005   if (C - d > 0) {
1006     fVec ds_dev0 = fVec::loadu(ds_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
1007     fVec ds_dev1 = fVec::loadu(ds_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
1008     fVec db_vec0 = fVec::loadu(db_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
1009     fVec db_vec1 = fVec::loadu(db_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
1010     Vec x_vec = Vec::loadu(X_ptr + d, C - d);
1011     Vec dy_vec = Vec::loadu(dY_ptr + d, C - d);
1012     auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1013     auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1014     ds_dev0 += x_vec0 * dy_vec0;
1015     ds_dev1 += x_vec1 * dy_vec1;
1016     db_vec0 += dy_vec0;
1017     db_vec1 += dy_vec1;
1018 
1019     ds_dev0.store(ds_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
1020     ds_dev1.store(ds_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
1021     db_vec0.store(db_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
1022     db_vec1.store(db_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
1023   }
1024 }
1025 
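// load_util loads up to 2 * fVec::size() parameter values and returns them as
// a pair of float vectors: for float parameters it is two (possibly masked)
// loads, for BFloat16/Half it is one masked load followed by convert_to_float.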
1026 template <typename T>
1027 inline typename std::enable_if<std::is_same<T, at::opmath_type<T>>::value,
1028   std::tuple<
1029   vec::Vectorized<T>,
1030   vec::Vectorized<T>>>::type
1031 load_util(const T* data_ptr, int64_t n) {
1032   using Vec = vec::Vectorized<T>;
1033   auto vec0 = Vec::loadu(data_ptr, n > Vec::size() ? Vec::size() : n);
1034   auto vec1 = Vec::loadu(
1035       data_ptr + Vec::size(), n > Vec::size() ? (n - Vec::size()) : 0);
1036   return std::tuple<Vec, Vec>(vec0, vec1);
1037 }
1038 
1039 template <typename T>
1040 inline typename std::enable_if<!std::is_same<T, at::opmath_type<T>>::value,
1041   std::tuple<
1042     vec::Vectorized<at::opmath_type<T>>,
1043     vec::Vectorized<at::opmath_type<T>>>
1044     >::type
1045 load_util(const T* data_ptr, int64_t n) {
1046   using Vec = vec::Vectorized<T>;
1047   auto vec = Vec::loadu(data_ptr, n);
1048   return convert_to_float<T>(vec);
1049 }
1050 
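// ApplyInputGradientsChannelsLastColMov applies dX = c1 * dY + c2 * X + c3 to
// one (n, g) block in channels-last layout, walking a block of channels in the
// outer loop and all HxW rows in the inner loop, so c1 is computed once per
// channel block. The RowMov variant further below applies the same update to a
// single spatial row at a time.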
1051 template <typename T, typename PT, typename opmath_t>
1052 inline typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
1053 ApplyInputGradientsChannelsLastColMov(
1054   const T* dY_data,
1055   const T* X_data,
1056   T* dX_data,
1057   const PT* rstd,
1058   const PT* gamma,
1059   opmath_t c2,
1060   opmath_t c3,
1061   int64_t HxW,
1062   int64_t C,
1063   int64_t D) {
1064   const bool gamma_null = (gamma == nullptr);
1065   int64_t d = 0;
1066   auto K = vec::Vectorized<T>::size();
1067   for (; d < D / K * K; d += K) {
1068     auto c1 = vec::Vectorized<T>(*rstd) *
1069         (gamma_null ? vec::Vectorized<T>(1)
1070                     : vec::Vectorized<T>::loadu(gamma + d));
1071     for (const auto m : c10::irange(HxW)) {
1072       const T* X_ptr = X_data + m * C;
1073       const T* dY_ptr = dY_data + m * C;
1074       T* dX_ptr = dX_data + m * C;
1075       auto dy_vec = vec::Vectorized<T>::loadu(dY_ptr + d);
1076       auto x_vec = vec::Vectorized<T>::loadu(X_ptr + d);
1077       auto dx_vec = c1 * dy_vec +
1078         vec::Vectorized<T>(c2) * x_vec + vec::Vectorized<T>(c3);
1079       dx_vec.store(dX_ptr + d);
1080     }
1081   }
1082   if (D - d > 0) {
1083     auto c1 = vec::Vectorized<T>(*rstd) *
1084         (gamma_null ? vec::Vectorized<T>(1)
1085                     : vec::Vectorized<T>::loadu(gamma + d, D - d));
1086     for (const auto m : c10::irange(HxW)) {
1087       const T* X_ptr = X_data + m * C;
1088       const T* dY_ptr = dY_data + m * C;
1089       T* dX_ptr = dX_data + m * C;
1090       auto dy_vec = vec::Vectorized<T>::loadu(dY_ptr + d, D - d);
1091       auto x_vec = vec::Vectorized<T>::loadu(X_ptr + d, D - d);
1092       auto dx_vec = c1 * dy_vec +
1093         vec::Vectorized<T>(c2) * x_vec + vec::Vectorized<T>(c3);
1094       dx_vec.store(dX_ptr + d, D - d);
1095     }
1096   }
1097 }
1098 
1099 template <typename T, typename PT, typename opmath_t>
1100 inline typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
1101 ApplyInputGradientsChannelsLastColMov(
1102     const T* dY_data,
1103     const T* X_data,
1104     T* dX_data,
1105     const PT* rstd,
1106     const PT* gamma,
1107     opmath_t c2,
1108     opmath_t c3,
1109     int64_t HxW,
1110     int64_t C,
1111     int64_t D) {
1112   using Vec = vec::Vectorized<T>;
1113   using fVec = vec::Vectorized<opmath_t>;
1114   const bool gamma_null = (gamma == nullptr);
1115   auto K = Vec::size();
1116   int64_t d = 0;
1117   for (; d < D / K * K; d += K) {
1118     auto [c1_0, c1_1] = gamma_null ? std::tuple<fVec, fVec>(fVec(1), fVec(1))
1119                                       : load_util(gamma + d, K);
1120     c1_0 = c1_0 * fVec(opmath_t(*rstd));
1121     c1_1 = c1_1 * fVec(opmath_t(*rstd));
1122     for (const auto m : c10::irange(HxW)) {
1123       const T* X_ptr = X_data + m * C;
1124       const T* dY_ptr = dY_data + m * C;
1125       T* dX_ptr = dX_data + m * C;
1126 
1127       Vec dy_vec = Vec::loadu(dY_ptr + d);
1128       Vec x_vec = Vec::loadu(X_ptr + d);
1129       auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1130       auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1131       fVec dx_vec0 = c1_0 * dy_vec0 + fVec(c2) * x_vec0 + fVec(c3);
1132       fVec dx_vec1 = c1_1 * dy_vec1 + fVec(c2) * x_vec1 + fVec(c3);
1133       convert_from_float<T>(dx_vec0, dx_vec1).store(dX_ptr + d);
1134     }
1135   }
1136   if (D - d > 0) {
1137     auto [c1_0, c1_1] = gamma_null ? std::tuple<fVec, fVec>(fVec(1), fVec(1))
1138                                       : load_util(gamma + d, D - d);
1139     c1_0 = c1_0 * fVec(opmath_t(*rstd));
1140     c1_1 = c1_1 * fVec(opmath_t(*rstd));
1141     for (const auto m : c10::irange(HxW)) {
1142       const T* X_ptr = X_data + m * C;
1143       const T* dY_ptr = dY_data + m * C;
1144       T* dX_ptr = dX_data + m * C;
1145       Vec dy_vec = Vec::loadu(dY_ptr + d, D - d);
1146       Vec x_vec = Vec::loadu(X_ptr + d, D - d);
1147       auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1148       auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1149       fVec dx_vec0 = c1_0 * dy_vec0 + fVec(c2) * x_vec0 + fVec(c3);
1150       fVec dx_vec1 = c1_1 * dy_vec1 + fVec(c2) * x_vec1 + fVec(c3);
1151       convert_from_float<T>(dx_vec0, dx_vec1).store(dX_ptr + d, D - d);
1152     }
1153   }
1154 }
1155 
1156 template <typename T, typename PT, typename opmath_t>
1157 inline typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
1158 ApplyInputGradientsChannelsLastRowMov(
1159   const T* dY_data,
1160   const T* X_data,
1161   T* dX_data,
1162   const PT* rstd,
1163   const PT* gamma,
1164   opmath_t c2,
1165   opmath_t c3,
1166   int64_t HxW,
1167   int64_t C,
1168   int64_t D) {
1169   const bool gamma_null = (gamma == nullptr);
1170   int64_t d = 0;
1171   auto K = vec::Vectorized<T>::size();
1172   for (; d < D / K * K; d += K) {
1173     auto c1 = vec::Vectorized<T>(*rstd) *
1174       (gamma_null ? vec::Vectorized<T>(1) : vec::Vectorized<T>::loadu(gamma + d));
1175     auto dy_vec = vec::Vectorized<T>::loadu(dY_data + d);
1176     auto x_vec = vec::Vectorized<T>::loadu(X_data + d);
1177     auto dx_vec = c1 * dy_vec +
1178       vec::Vectorized<T>(c2) * x_vec + vec::Vectorized<T>(c3);
1179     dx_vec.store(dX_data + d);
1180   }
1181   if (D - d > 0) {
1182     auto c1 = vec::Vectorized<T>(*rstd) *
1183       (gamma_null ? vec::Vectorized<T>(1) : vec::Vectorized<T>::loadu(gamma + d, D - d));
1184     auto dy_vec = vec::Vectorized<T>::loadu(dY_data + d, D - d);
1185     auto x_vec = vec::Vectorized<T>::loadu(X_data + d, D - d);
1186     auto dx_vec = c1 * dy_vec +
1187       vec::Vectorized<T>(c2) * x_vec + vec::Vectorized<T>(c3);
1188     dx_vec.store(dX_data + d, D - d);
1189   }
1190 }
1191 
1192 template <typename T, typename PT, typename opmath_t>
1193 inline typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
1194 ApplyInputGradientsChannelsLastRowMov(
1195     const T* dY_data,
1196     const T* X_data,
1197     T* dX_data,
1198     const PT* rstd,
1199     const PT* gamma,
1200     opmath_t c2,
1201     opmath_t c3,
1202     int64_t HxW,
1203     int64_t C,
1204     int64_t D) {
1205   using Vec = vec::Vectorized<T>;
1206   using fVec = vec::Vectorized<opmath_t>;
1207   const bool gamma_null = (gamma == nullptr);
1208   auto K = Vec::size();
1209   int64_t d = 0;
1210   for (; d < D / K * K; d += K) {
1211     auto [c1_0, c1_1] = gamma_null ? std::tuple<fVec, fVec>(fVec(1), fVec(1))
1212                                       : load_util(gamma + d, K);
1213     c1_0 = c1_0 * fVec(opmath_t(*rstd));
1214     c1_1 = c1_1 * fVec(opmath_t(*rstd));
1215     Vec dy_vec = Vec::loadu(dY_data + d);
1216     Vec x_vec = Vec::loadu(X_data + d);
1217     auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1218     auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1219     fVec dx_vec0 = c1_0 * dy_vec0 + fVec(c2) * x_vec0 + fVec(c3);
1220     fVec dx_vec1 = c1_1 * dy_vec1 + fVec(c2) * x_vec1 + fVec(c3);
1221     convert_from_float<T>(dx_vec0, dx_vec1).store(dX_data + d);
1222   }
1223   if (D - d > 0) {
1224     auto [c1_0, c1_1] = gamma_null ? std::tuple<fVec, fVec>(fVec(1), fVec(1))
1225                                       : load_util(gamma + d, D - d);
1226     c1_0 = c1_0 * fVec(opmath_t(*rstd));
1227     c1_1 = c1_1 * fVec(opmath_t(*rstd));
1228     Vec dy_vec = Vec::loadu(dY_data + d, D - d);
1229     Vec x_vec = Vec::loadu(X_data + d, D - d);
1230     auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1231     auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1232     fVec dx_vec0 = c1_0 * dy_vec0 + fVec(c2) * x_vec0 + fVec(c3);
1233     fVec dx_vec1 = c1_1 * dy_vec1 + fVec(c2) * x_vec1 + fVec(c3);
1234     convert_from_float<T>(dx_vec0, dx_vec1).store(dX_data + d, D - d);
1235   }
1236 }
1237 
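// CalcInternalGradientsChannelsLast computes, for one (n, g) block of D
// channels in channels-last layout, the per-channel sums
//   ds_ptr[d] = sum_{HxW} X * dY
//   db_ptr[d] = sum_{HxW} dY
// and returns their gamma-weighted totals (ds_gamma, db_gamma), analogous to
// the u / v sums used by GroupNormInputBackward above.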
1238 template <typename T, typename PT, typename opmath_t>
1239 inline typename std::
1240     enable_if<std::is_same<T, opmath_t>::value, std::tuple<opmath_t, opmath_t>>::type
1241     CalcInternalGradientsChannelsLast(
1242     const T* X_data,
1243     const T* dY_data,
1244     const PT* gamma_ptr,
1245     opmath_t* ds_ptr,
1246     opmath_t* db_ptr,
1247     int64_t HxW,
1248     int64_t C,
1249     int64_t D) {
1250   using Vec = vec::Vectorized<T>;
1251   const bool gamma_null = (gamma_ptr == nullptr);
1252   constexpr int64_t K = Vec::size();
1253   const int64_t inner_size = D / K * K;
1254   int64_t d = 0;
1255   opmath_t ds_gamma{0}, db_gamma{0};
1256   for (; d < inner_size; d += K) {
1257     Vec acc0_vec{0}, acc1_vec{0};
1258     for (const auto m : c10::irange(HxW)) {
1259       const T* X_ptr = X_data + m * C;
1260       const T* dY_ptr = dY_data + m * C;
1261       Vec x_vec = Vec::loadu(X_ptr + d);
1262       Vec dy_vec = Vec::loadu(dY_ptr + d);
1263       acc0_vec += x_vec * dy_vec;
1264       acc1_vec += dy_vec;
1265     }
1266     acc0_vec.store(ds_ptr + d);
1267     acc1_vec.store(db_ptr + d);
1268     ds_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
1269       acc0_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d)));
1270     db_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
1271       acc1_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d)));
1272   }
1273   if (D - d > 0) {
1274     Vec acc0_vec{0}, acc1_vec{0};
1275     for (const auto m : c10::irange(HxW)) {
1276       const T* X_ptr = X_data + m * C;
1277       const T* dY_ptr = dY_data + m * C;
1278       Vec x_vec = Vec::loadu(X_ptr + d, D - d);
1279       Vec dy_vec = Vec::loadu(dY_ptr + d, D - d);
1280       acc0_vec += x_vec * dy_vec;
1281       acc1_vec += dy_vec;
1282     }
1283     acc0_vec.store(ds_ptr + d, D - d);
1284     acc1_vec.store(db_ptr + d, D - d);
1285     ds_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
1286       acc0_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d, D - d)));
1287     db_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
1288       acc1_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d, D - d)));
1289   }
1290   return std::tuple<opmath_t, opmath_t>(ds_gamma, db_gamma);
1291 }
1292 
1293 template <typename T, typename PT, typename opmath_t>
1294 inline typename std::
1295     enable_if<!std::is_same<T, opmath_t>::value, std::tuple<opmath_t, opmath_t>>::type
1296     CalcInternalGradientsChannelsLast(
1297         const T* X_data,
1298         const T* dY_data,
1299         const PT* gamma_ptr,
1300         opmath_t* ds_ptr,
1301         opmath_t* db_ptr,
1302         int64_t HxW,
1303         int64_t C,
1304         int64_t D) {
1305   using Vec = vec::Vectorized<T>;
1306   using fVec = vec::Vectorized<opmath_t>;
1307   const bool gamma_null = (gamma_ptr == nullptr);
1308   constexpr int64_t K = Vec::size();
1309   const int64_t inner_size = D / K * K;
1310   opmath_t ds_gamma{0}, db_gamma{0};
1311   int64_t d = 0;
1312   for (; d < inner_size; d += K) {
1313     fVec acc0_vec0{0}, acc0_vec1{0}, acc1_vec0{0}, acc1_vec1{0};
1314     for (const auto m : c10::irange(HxW)) {
1315       const T* X_ptr = X_data + m * C;
1316       const T* dY_ptr = dY_data + m * C;
1317       Vec x_vec = Vec::loadu(X_ptr + d);
1318       Vec dy_vec = Vec::loadu(dY_ptr + d);
1319       auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1320       auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1321       acc0_vec0 += x_vec0 * dy_vec0;
1322       acc0_vec1 += x_vec1 * dy_vec1;
1323       acc1_vec0 += dy_vec0;
1324       acc1_vec1 += dy_vec1;
1325     }
1326     acc0_vec0.store(ds_ptr + d);
1327     acc0_vec1.store(ds_ptr + d + fVec::size());
1328     acc1_vec0.store(db_ptr + d);
1329     acc1_vec1.store(db_ptr + d + fVec::size());
1330     auto [gamma_vec0, gamma_vec1] = gamma_null ?
1331       std::tuple<fVec, fVec>(fVec(1), fVec(1)) : load_util(gamma_ptr + d, K);
1332     ds_gamma += vec::vec_reduce_all(
1333         [](fVec& x, fVec& y) { return x + y; }, acc0_vec0 * gamma_vec0);
1334     ds_gamma += vec::vec_reduce_all(
1335         [](fVec& x, fVec& y) { return x + y; }, acc0_vec1 * gamma_vec1);
1336     db_gamma += vec::vec_reduce_all(
1337         [](fVec& x, fVec& y) { return x + y; }, acc1_vec0 * gamma_vec0);
1338     db_gamma += vec::vec_reduce_all(
1339         [](fVec& x, fVec& y) { return x + y; }, acc1_vec1 * gamma_vec1);
1340   }
1341   for (; d < D; d++) {
1342     opmath_t acc0{0}, acc1{0};
1343     for (const auto m : c10::irange(HxW)) {
1344       const T* X_ptr = X_data + m * C;
1345       const T* dY_ptr = dY_data + m * C;
1346       acc0 += opmath_t(X_ptr[d]) * opmath_t(dY_ptr[d]);
1347       acc1 += opmath_t(dY_ptr[d]);
1348     }
1349     ds_ptr[d] = acc0;
1350     db_ptr[d] = acc1;
1351     opmath_t gamma_val = gamma_null ? opmath_t(1) : opmath_t(gamma_ptr[d]);
1352     ds_gamma += acc0 * gamma_val;
1353     db_gamma += acc1 * gamma_val;
1354   }
1355 
1356   return std::tuple<opmath_t, opmath_t>(ds_gamma, db_gamma);
1357 }
1358 
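// Channels-last backward kernel for a single (input dtype, parameter dtype)
// combination: it first reduces the internal gradients ds and db per (n, c),
// then computes dX, dgamma and dbeta from them.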
1359 template <typename T, typename PT>
1360 void GroupNormBackwardKernelImplChannelsLastInternal(
1361     const Tensor& dY,
1362     const Tensor& X,
1363     const Tensor& mean,
1364     const Tensor& rstd,
1365     const Tensor& gamma,
1366     int64_t N,
1367     int64_t C,
1368     int64_t HxW,
1369     int64_t group,
1370     Tensor& dX,
1371     Tensor& dgamma,
1372     Tensor& dbeta) {
1373   TORCH_CHECK(dY.numel() == N * C * HxW);
1374   TORCH_CHECK(X.numel() == N * C * HxW);
1375   TORCH_CHECK(mean.numel() == N * group);
1376   TORCH_CHECK(rstd.numel() == N * group);
1377   TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
1378   int64_t D = C / group;
1379   int64_t G = group;
1380   const T* dY_data = dY.const_data_ptr<T>();
1381   const T* X_data = X.const_data_ptr<T>();
1382   const PT* mean_data = mean.const_data_ptr<PT>();
1383   const PT* rstd_data = rstd.const_data_ptr<PT>();
1384   const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
1385   T* dX_data = dX.defined() ? dX.data_ptr<T>() : nullptr;
1386   PT* dgamma_data = dgamma.defined() ? dgamma.data_ptr<PT>() : nullptr;
1387   PT* dbeta_data = dbeta.defined() ? dbeta.data_ptr<PT>() : nullptr;
1388   const bool gamma_null = (gamma_data == nullptr);
1389   using opmath_t = at::opmath_type<T>;
1390   Tensor ds = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
1391   Tensor db = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
1392   opmath_t* ds_data = ds.data_ptr<opmath_t>();
1393   opmath_t* db_data = db.data_ptr<opmath_t>();
1394   const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
1395 
1396   // Similar to the channels last forward, the channels last backward also has 2 impls.
1397   // impl-1: parallel on N * G. Only one omp session is needed for the input gradients,
1398   //   but memory access per thread is non-contiguous.
1399   //
1400   // impl-2: parallel on N * HxW. Memory access per thread is contiguous,
1401   //   but it requires an extra temp buffer of size {T, N, 2C}.
1402 
1403   // Generally impl-2 performs better when HxW is large enough that the
1404   //   data per thread {NHWC / T} is much larger than the temp buffer per thread {2NC}.
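  //   The ratio of data to buffer per thread is roughly HxW / (2 * T);
  //   e.g. HxW = 2048 with 16 threads gives a ratio of 64.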
1405   constexpr int64_t feature_map_threshold = 2048;
1406   if (HxW < feature_map_threshold) {
1407     // impl-1: parallel on N * G.
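    // Each task handles one (n, g) group; data_index_init/data_index_step
    // decode the flat task index into the (n, g) pair.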
1408     at::parallel_for(0, N * G, 1, [=](int64_t begin, int64_t end) {
1409       int64_t n{0}, g{0};
1410       data_index_init(begin, n, N, g, G);
1411       for (const auto i : c10::irange(begin, end)) {
1412         // Step 1. Compute internal gradients.
1413         opmath_t* ds_ptr = ds_data + i * D;
1414         opmath_t* db_ptr = db_data + i * D;
1415         const T* X_ptr = X_data + n * HxW * C + g * D;
1416         const T* dY_ptr = dY_data + n * HxW * C + g * D;
1417         const PT* gamma_ptr = gamma_null ? gamma_data : (gamma_data + g * D);
1418         auto [ds_gamma, db_gamma] = CalcInternalGradientsChannelsLast<T, PT, opmath_t>(
1419           X_ptr,
1420           dY_ptr,
1421           gamma_ptr,
1422           ds_ptr,
1423           db_ptr,
1424           HxW,
1425           C,
1426           D);
1427 
1428         // Step 2. Compute dX.
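        // dX is computed as c1 * dY + c2 * X + c3 per element, where
        // c1 = rstd * gamma is applied per channel inside the ColMov helper,
        // and c2, c3 below fold in mean, rstd and s = 1 / (D * HxW).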
1429         T* dX_ptr = dX_data + n * HxW * C + g * D;
1430         const PT* rstd_ptr = rstd_data + i;
1431         const opmath_t c2 = (db_gamma * opmath_t(mean_data[i]) - ds_gamma) *
1432             opmath_t(rstd_data[i]) * opmath_t(rstd_data[i]) * opmath_t(rstd_data[i]) * s;
1433         const opmath_t c3 = -c2 * opmath_t(mean_data[i]) - db_gamma * opmath_t(rstd_data[i]) * s;
1434         ApplyInputGradientsChannelsLastColMov<T, PT, opmath_t>(dY_ptr, X_ptr, dX_ptr, rstd_ptr, gamma_ptr, c2, c3, HxW, C, D);
1435         data_index_step(n, N, g, G);
1436       }
1437     });
1438 
1439   } else {
1440     // impl-2: parallel on N * HxW.
1441     int num_threads = at::get_num_threads();
1442     Tensor buffer = at::empty({num_threads, N, 2 * C},
1443       X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value)).zero_();
1444     opmath_t* buffer_data = buffer.data_ptr<opmath_t>();
1445 
1446     Tensor tmp_buffer = at::empty({N, 2 * G},
1447       X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
1448     opmath_t* tmp_buffer_data = tmp_buffer.data_ptr<opmath_t>();
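    // buffer is laid out as {T, N, 2C}: for thread t and sample n, the first C
    // entries hold the partial ds sums (X * dY) and the next C entries hold the
    // partial db sums (dY). tmp_buffer later holds the per-(n, g) pair
    // (ds_gamma, db_gamma) interleaved along its last dimension of size 2 * G.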
1449 
1450     // Step 1. Each thread computes its own internal gradients into its slice of the buffer.
1451     at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
1452       int tid = at::get_thread_num();
1453       opmath_t* buffer_ptr = buffer_data + tid * N * 2 * C;
1454       int64_t n{0}, m{0};
1455       data_index_init(begin, n, N, m, HxW);
1456       for (const auto i : c10::irange(begin, end)) {
1457         opmath_t* ds_ptr = buffer_ptr + n * 2 * C;
1458         opmath_t* db_ptr = ds_ptr + C;
1459         const T* X_ptr = X_data + i * C;
1460         const T* dY_ptr = dY_data + i * C;
1461 
1462         DsDbRowwiseMomentsChannelsLast<T, opmath_t>(dY_ptr, X_ptr, ds_ptr, db_ptr, C);
1463         data_index_step(n, N, m, HxW);
1464       }
1465     });
1466 
1467     // Step 2. Collect the internal gradients from each thread and
1468     // reduce them into the final internal gradients in ds, db, and tmp_buffer.
1469     for (const auto n : c10::irange(N)) {
1470       for (const auto g : c10::irange(G)) {
1471         opmath_t ds_gamma{0}, db_gamma{0};
1472         for (const auto d : c10::irange(D)) {
1473           opmath_t ds_val{0}, db_val{0};
1474           for (const auto t : c10::irange(num_threads)) {
1475             opmath_t* buffer_ptr = buffer_data + t * N * 2 * C + n * 2 * C;
1476             opmath_t gamma_val = gamma_null ? opmath_t(1) : opmath_t(gamma_data[g * D + d]);
1477             ds_gamma += buffer_ptr[g * D + d] * gamma_val;
1478             db_gamma += buffer_ptr[g * D + d + C] * gamma_val;
1479             ds_val += buffer_ptr[g * D + d];
1480             db_val += buffer_ptr[g * D + d + C];
1481 
1482           }
1483           ds_data[n * C + g * D + d] = ds_val;
1484           db_data[n * C + g * D + d] = db_val;
1485         }
1486         tmp_buffer_data[n * 2 * G + 2 * g] = ds_gamma;
1487         tmp_buffer_data[n * 2 * G + 2 * g + 1] = db_gamma;
1488       }
1489     }
1490 
1491     // Step 3. Compute dX.
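    // Same c1 * dY + c2 * X + c3 decomposition as impl-1, but parallelized
    // over rows (N * HxW) and applied via the RowMov variant.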
1492     if (dX_data != nullptr) {
1493       at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
1494         int64_t n{0}, m{0};
1495         data_index_init(begin, n, N, m, HxW);
1496         for (const auto i : c10::irange(begin, end)) {
1497           for (const auto g : c10::irange(G)) {
1498             const T* X_ptr = X_data + i * C + g * D;
1499             const T* dY_ptr = dY_data + i * C + g * D;
1500             T* dX_ptr = dX_data + i * C + g * D;
1501             const PT* mean_ptr = mean_data + n * G + g;
1502             const PT* rstd_ptr = rstd_data + n * G + g;
1503             const PT* gamma_ptr = gamma_null ? gamma_data : (gamma_data + g * D);
1504             opmath_t ds_val = tmp_buffer_data[n * 2 * G + 2 * g];
1505             opmath_t db_val = tmp_buffer_data[n * 2 * G + 2 * g + 1];
1506 
1507             const opmath_t c2 = (db_val * opmath_t(*mean_ptr) - ds_val) *
1508                 opmath_t(*rstd_ptr) * opmath_t(*rstd_ptr) * opmath_t(*rstd_ptr) * s;
1509             const opmath_t c3 = -c2 * opmath_t(*mean_ptr) - db_val * opmath_t(*rstd_ptr) * s;
1510             ApplyInputGradientsChannelsLastRowMov<T, PT, opmath_t>(dY_ptr, X_ptr, dX_ptr, rstd_ptr, gamma_ptr, c2, c3, HxW, C, D);
1511           }
1512 
1513           data_index_step(n, N, m, HxW);
1514         }
1515       });
1516     }
1517 
1518   }
1519 
1520   // Finally compute dgamma and dbeta.
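  // GammaBackward reduces ds and db over the batch together with the saved
  // mean and rstd; BetaBackward reduces db over the batch.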
1521   if (dgamma_data != nullptr) {
1522     GammaBackward(
1523         N, C, group, mean_data, rstd_data, ds_data, db_data, dgamma_data);
1524   }
1525   if (dbeta_data != nullptr) {
1526     BetaBackward(N, C, db_data, dbeta_data);
1527   }
1528 }
1529 
1530 void GroupNormBackwardKernelImpl(
1531     const Tensor& dY,
1532     const Tensor& X,
1533     const Tensor& mean,
1534     const Tensor& rstd,
1535     const Tensor& gamma,
1536     int64_t N,
1537     int64_t C,
1538     int64_t HxW,
1539     int64_t group,
1540     Tensor& dX,
1541     Tensor& dgamma,
1542     Tensor& dbeta) {
1543   // In training, using Amp to enable a lower precision data type,
1544   // i.e., BFloat16 or Half, is recommended.
1545   // Amp keeps the module parameters in the opmath dtype, i.e. float,
1546   // while the input/output stay in the lower precision data type.
1547   // Keeping the parameters themselves in BFloat16 or Half may cause a large precision loss.
1548   const bool mixed_type = is_mixed_type(dY, mean);
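  // When mixed_type is true, the statistics and gamma are kept in float
  // (param_t = opmath_type<scalar_t>) while dY/X are BFloat16 or Half.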
1549   switch (X.suggest_memory_format()) {
1550     case at::MemoryFormat::Contiguous: {
1551       AT_DISPATCH_FLOATING_TYPES_AND2(
1552         ScalarType::BFloat16, ScalarType::Half, X.scalar_type(), "GroupNormBackwardKernelImpl", [&]() {
1553         using param_t = at::opmath_type<scalar_t>;
1554         if (mixed_type) {
1555           GroupNormBackwardKernelImplInternal<scalar_t, param_t>(
1556               dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
1557         } else {
1558           GroupNormBackwardKernelImplInternal<scalar_t, scalar_t>(
1559               dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
1560         }
1561       });
1562       break;
1563     }
1564     case at::MemoryFormat::ChannelsLast:
1565     case at::MemoryFormat::ChannelsLast3d: {
1566       AT_DISPATCH_FLOATING_TYPES_AND2(
1567         ScalarType::BFloat16, ScalarType::Half, X.scalar_type(), "GroupNormBackwardKernelImpl", [&]() {
1568         using param_t = at::opmath_type<scalar_t>;
1569         if (mixed_type) {
1570           GroupNormBackwardKernelImplChannelsLastInternal<scalar_t, param_t>(
1571               dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
1572         } else {
1573           GroupNormBackwardKernelImplChannelsLastInternal<scalar_t, scalar_t>(
1574               dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
1575         }
1576       });
1577       break;
1578     }
1579     default:
1580       TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, ChannelsLast3d, Contiguous");
1581   }
1582 
1583 }
1584 
1585 } // namespace
1586 
1587 REGISTER_DISPATCH(GroupNormKernel, &GroupNormKernelImpl);
1588 REGISTER_DISPATCH(GroupNormBackwardKernel, &GroupNormBackwardKernelImpl);
1589 
1590 } // namespace at::native
1591