1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/group_norm.h>
3
4 #include <algorithm>
5 #include <array>
6 #include <numeric>
7
8 #include <ATen/core/Tensor.h>
9 #include <ATen/Dispatch.h>
10 #include <ATen/cpu/vec/vec.h>
11 #include <ATen/cpu/vec/functional.h>
12 #include <ATen/native/cpu/utils.h>
13 #include <ATen/native/cpu/moments_utils.h>
14 #include <ATen/native/cpu/mixed_data_type.h>
15 #include <ATen/OpMathType.h>
16 #include <c10/util/irange.h>
17
18 #ifndef AT_PER_OPERATOR_HEADERS
19 #include <ATen/Functions.h>
20 #else
21 #include <ATen/ops/empty.h>
22 #endif
23
24 namespace at::native {
25
26 namespace {
27
28 template <typename T, typename PT>
29 void GroupNormKernelImplInternal(
30 const Tensor& X,
31 const Tensor& gamma,
32 const Tensor& beta,
33 int64_t N,
34 int64_t C,
35 int64_t HxW,
36 int64_t group,
37 double eps,
38 Tensor& Y,
39 Tensor& mean,
40 Tensor& rstd) {
41 TORCH_CHECK(X.numel() == N * C * HxW);
42 TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
43 TORCH_CHECK(!beta.defined() || beta.numel() == C);
44 const int64_t G = group;
45 const int64_t D = C / G;
46 const T* X_data = X.const_data_ptr<T>();
47 const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
48 const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
49 T* Y_data = Y.data_ptr<T>();
50 PT* mean_data = mean.data_ptr<PT>();
51 PT* rstd_data = rstd.data_ptr<PT>();
52 const bool gamma_null = (gamma_data == nullptr);
53 const bool beta_null = beta_data == nullptr;
54 const int64_t inner_size = D * HxW;
55
56 using opmath_t = at::opmath_type<T>;
57
58 at::parallel_for(0, N * G, 1, [&](int64_t start, int64_t end) {
59 for (const auto i : c10::irange(start, end)) {
60 const T* X_ptr = X_data + i * inner_size;
61 auto [mean_val, rstd_val] = RowwiseMoments(X_ptr, inner_size);
62 rstd_val = opmath_t(1) / std::sqrt(std::max(rstd_val, opmath_t(0)) + eps);
63 if (gamma_null && beta_null) {
64 T* Y_ptr = Y_data + i * inner_size;
65 for (const auto j : c10::irange(inner_size)) {
66 Y_ptr[j] = (X_ptr[j] - mean_val) * rstd_val;
67 }
68 } else {
69 const int64_t g = i % G;
70 for (const auto j : c10::irange(D)) {
71 const int64_t c = g * D + j;
72 const opmath_t scale = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
73 const opmath_t bias = -scale * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
74 X_ptr = X_data + (i * D + j) * HxW;
75 T* Y_ptr = Y_data + (i * D + j) * HxW;
76 for (const auto k : c10::irange(HxW)) {
77 Y_ptr[k] = scale * X_ptr[k] + bias;
78 }
79 }
80 }
81 mean_data[i] = mean_val;
82 rstd_data[i] = rstd_val;
83 }
84 });
85 }
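
// A minimal scalar reference of what the kernel above computes for one (n, g)
// slice of length inner_size = D * HxW, assuming plain float, no gamma/beta and
// <cmath>/<algorithm> available; a sketch for checking the vectorized paths,
// not used by the kernels themselves:
//
//   void reference_group_norm_slice(const float* x, float* y, int64_t len, double eps) {
//     double sum = 0.0, sumsq = 0.0;
//     for (int64_t j = 0; j < len; ++j) { sum += x[j]; sumsq += x[j] * x[j]; }
//     const double mean = sum / len;
//     const double var = std::max(sumsq / len - mean * mean, 0.0);
//     const double rstd = 1.0 / std::sqrt(var + eps);
//     for (int64_t j = 0; j < len; ++j) { y[j] = static_cast<float>((x[j] - mean) * rstd); }
//   }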
86
87 template <typename T>
88 typename std::enable_if<std::is_same<T, at::opmath_type<T>>::value,
89 std::tuple<T, T>>::type
90 ColumnwiseMoments(
91 const T* X_data,
92 int64_t HxW,
93 int64_t C,
94 int64_t D) {
95 using Vec = vec::Vectorized<T>;
96 constexpr int64_t K = Vec::size();
97 const int64_t inner_size = D / K * K;
98 Vec acc0_vec{0}, acc1_vec{0};
99 for (const auto m : c10::irange(HxW)) {
100 const T* X_ptr = X_data + m * C;
101 int64_t d = 0;
102 for (; d < inner_size; d += K) {
103 Vec x_vec = Vec::loadu(X_ptr + d);
104 acc0_vec += x_vec;
105 acc1_vec += x_vec * x_vec;
106 }
107 if (D - d > 0) {
108 Vec x_vec = Vec::loadu(X_ptr + d, D - d);
109 acc0_vec += x_vec;
110 acc1_vec += x_vec * x_vec;
111 }
112 }
113 T mean_val = vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; }, acc0_vec);
114 T rstd_val = vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; }, acc1_vec);
115 return std::tuple<T, T>(mean_val, rstd_val);
116 }
117
118
119 // std::is_same<T, at::BFloat16> || std::is_same<T, at::Half>
120 template <typename T>
121 typename std::enable_if<!std::is_same<T, at::opmath_type<T>>::value,
122 std::tuple<at::opmath_type<T>, at::opmath_type<T>>>::type
123 ColumnwiseMoments(
124 const T* X_data,
125 int64_t HxW,
126 int64_t C,
127 int64_t D) {
128 using opmath_t = at::opmath_type<T>;
129 using Vec = vec::Vectorized<T>;
130 using fVec = vec::Vectorized<opmath_t>;
131 constexpr int64_t K = Vec::size();
132 const int64_t inner_size = D / K * K;
133 fVec acc0_fvec{0}, acc1_fvec{0}, zero{0};
134 for (const auto m : c10::irange(HxW)) {
135 const T* X_ptr = X_data + m * C;
136 int64_t d = 0;
137 for (; d < inner_size; d += K) {
138 Vec x_bvec = Vec::loadu(X_ptr + d);
139 auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
140 acc0_fvec += x_fvec0 + x_fvec1;
141 acc1_fvec += x_fvec0 * x_fvec0 + x_fvec1 * x_fvec1;
142 }
143 if (D - d > 0) {
144 Vec x_bvec = Vec::loadu(X_ptr + d, D - d);
145 auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
146 if (D - d > fVec::size()) {
147 x_fvec1 = fVec::set(zero, x_fvec1, D - d - fVec::size());
148 acc0_fvec += x_fvec0 + x_fvec1;
149 acc1_fvec += x_fvec0 * x_fvec0 + x_fvec1 * x_fvec1;
150 } else {
151 x_fvec0 = fVec::set(zero, x_fvec0, D - d);
152 acc0_fvec += x_fvec0;
153 acc1_fvec += x_fvec0 * x_fvec0;
154 }
155 }
156 }
157 opmath_t mean_val = vec::vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, acc0_fvec);
158 opmath_t rstd_val = vec::vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, acc1_fvec);
159 return std::tuple<opmath_t, opmath_t>(mean_val, rstd_val);
160 }
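
// Despite the name, both ColumnwiseMoments overloads return the raw sums
// acc0 = sum(x) and acc1 = sum(x^2) over the {HxW, D} slice; the caller turns
// them into moments with s = 1 / (D * HxW), roughly:
//
//   mean = acc0 * s
//   var  = max(acc1 * s - mean * mean, 0)
//   rstd = 1 / sqrt(var + eps)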
161
162 template <typename T, typename opmath_t>
163 inline typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
164 CalcMeanVar(
165 const T* X_ptr,
166 opmath_t* mean_ptr,
167 opmath_t* rstd_ptr,
168 int64_t C) {
169 using Vec = vec::Vectorized<T>;
170 vec::map2<T>(
171 [](Vec x, Vec y) { return x + y; },
172 mean_ptr,
173 X_ptr,
174 mean_ptr,
175 C);
176 vec::map2<T>(
177 [](Vec x, Vec y) { return x * x + y; },
178 rstd_ptr,
179 X_ptr,
180 rstd_ptr,
181 C);
182 }
183
184 // std::is_same<T, at::BFloat16> || std::is_same<T, at::Half>
185 template <typename T, typename opmath_t>
186 inline typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
187 CalcMeanVar(
188 const T* X_ptr,
189 opmath_t* mean_ptr,
190 opmath_t* rstd_ptr,
191 int64_t C) {
192 using fVec = vec::Vectorized<opmath_t>;
193 using Vec = vec::Vectorized<T>;
194 int64_t d = 0;
195 for (; d < C - (C % Vec::size()); d += Vec::size()) {
196 Vec data_bvec = Vec::loadu(X_ptr + d);
197 fVec mean_fvec0 = fVec::loadu(mean_ptr + d);
198 fVec mean_fvec1 = fVec::loadu(mean_ptr + d + fVec::size());
199 fVec rstd_fvec0 = fVec::loadu(rstd_ptr + d);
200 fVec rstd_fvec1 = fVec::loadu(rstd_ptr + d + fVec::size());
201 auto [data_fvec0, data_fvec1] = convert_to_float<T>(data_bvec);
202 mean_fvec0 = data_fvec0 + mean_fvec0;
203 mean_fvec1 = data_fvec1 + mean_fvec1;
204 rstd_fvec0 = data_fvec0 * data_fvec0 + rstd_fvec0;
205 rstd_fvec1 = data_fvec1 * data_fvec1 + rstd_fvec1;
206 mean_fvec0.store(mean_ptr + d);
207 mean_fvec1.store(mean_ptr + d + fVec::size());
208 rstd_fvec0.store(rstd_ptr + d);
209 rstd_fvec1.store(rstd_ptr + d + fVec::size());
210 }
211 if (C - d > 0) {
212 Vec data_bvec = Vec::loadu(X_ptr + d, C - d);
213 fVec mean_fvec0 = fVec::loadu(mean_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
214 fVec mean_fvec1 = fVec::loadu(mean_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
215 fVec rstd_fvec0 = fVec::loadu(rstd_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
216 fVec rstd_fvec1 = fVec::loadu(rstd_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
217 auto [data_fvec0, data_fvec1] = convert_to_float<T>(data_bvec);
218 mean_fvec0 = data_fvec0 + mean_fvec0;
219 mean_fvec1 = data_fvec1 + mean_fvec1;
220 rstd_fvec0 = data_fvec0 * data_fvec0 + rstd_fvec0;
221 rstd_fvec1 = data_fvec1 * data_fvec1 + rstd_fvec1;
222 mean_fvec0.store(mean_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
223 mean_fvec1.store(mean_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
224 rstd_fvec0.store(rstd_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
225 rstd_fvec1.store(rstd_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
226 }
227 }
228
229 template <typename T, typename opmath_t>
230 inline typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
231 ApplyScaleBias(
232 T* Y_ptr,
233 const T* X_ptr,
234 const opmath_t* scale_ptr,
235 const opmath_t* bias_ptr,
236 int64_t C) {
237 using Vec = vec::Vectorized<T>;
238 vec::map3<T>(
239 [](Vec x, Vec scale, Vec bias) { return x * scale + bias; },
240 Y_ptr,
241 X_ptr,
242 scale_ptr,
243 bias_ptr,
244 C);
245 }
246
247 // std::is_same<T, at::BFloat16> || std::is_same<T, at::Half>
248 template <typename T, typename opmath_t>
249 inline typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
250 ApplyScaleBias(
251 T* Y_ptr,
252 const T* X_ptr,
253 const opmath_t* scale_ptr,
254 const opmath_t* bias_ptr,
255 int64_t C) {
256 using fVec = vec::Vectorized<opmath_t>;
257 using Vec = vec::Vectorized<T>;
258 int64_t d = 0;
259 for (; d < C - (C % Vec::size()); d += Vec::size()) {
260 Vec data_bvec = Vec::loadu(X_ptr + d);
261 fVec scale_fvec0 = fVec::loadu(scale_ptr + d);
262 fVec scale_fvec1 = fVec::loadu(scale_ptr + d + fVec::size());
263 fVec bias_fvec0 = fVec::loadu(bias_ptr + d);
264 fVec bias_fvec1 = fVec::loadu(bias_ptr + d + fVec::size());
265 auto [data_fvec0, data_fvec1] = convert_to_float<T>(data_bvec);
266 fVec out0 = data_fvec0 * scale_fvec0 + bias_fvec0;
267 fVec out1 = data_fvec1 * scale_fvec1 + bias_fvec1;
268 convert_from_float<T>(out0, out1).store(Y_ptr + d);
269 }
270 if (C - d > 0) {
271 Vec data_bvec = Vec::loadu(X_ptr + d, C - d);
272 fVec scale_fvec0 = fVec::loadu(scale_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
273 fVec scale_fvec1 = fVec::loadu(scale_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
274 fVec bias_fvec0 = fVec::loadu(bias_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
275 fVec bias_fvec1 = fVec::loadu(bias_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
276 auto [data_fvec0, data_fvec1] = convert_to_float<T>(data_bvec);
277 fVec out0 = data_fvec0 * scale_fvec0 + bias_fvec0;
278 fVec out1 = data_fvec1 * scale_fvec1 + bias_fvec1;
279 convert_from_float<T>(out0, out1).store(Y_ptr + d, C - d);
280 }
281 }
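
// ApplyScaleBias applies the affine form of the normalization. The callers below
// fold mean/rstd/gamma/beta into per-channel coefficients first, so that
//
//   scale[c] = rstd * gamma[c]
//   bias[c]  = -scale[c] * mean + beta[c]
//   y        = x * scale[c] + bias[c]  ==  (x - mean) * rstd * gamma[c] + beta[c]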
282
283 template <typename T, typename PT>
284 void GroupNormKernelImplChannelsLastInternal(
285 const Tensor& X,
286 const Tensor& gamma,
287 const Tensor& beta,
288 int64_t N,
289 int64_t C,
290 int64_t HxW,
291 int64_t group,
292 double eps,
293 Tensor& Y,
294 Tensor& mean,
295 Tensor& rstd) {
296 TORCH_CHECK(X.numel() == N * C * HxW);
297 TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
298 TORCH_CHECK(!beta.defined() || beta.numel() == C);
299 const int64_t G = group;
300 const int64_t D = C / G;
301 const T* X_data = X.const_data_ptr<T>();
302 const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
303 const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
304 T* Y_data = Y.data_ptr<T>();
305 PT* mean_data = mean.data_ptr<PT>();
306 PT* rstd_data = rstd.data_ptr<PT>();
307
308 using opmath_t = at::opmath_type<T>;
309
310 const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
311 const bool gamma_null = (gamma_data == nullptr);
312 const bool beta_null = beta_data == nullptr;
313
314 // NB: About the algorithm chosen:
315 //
316 // On channels last, GroupNorm has an input shape of {N, H, W, GD}.
317 // Mean and rstd are collected per n and g, which involves a reduction
318 // over non-adjacent dimensions. We can parallelize with the following 2 impls:
319 //
320 // impl-1: parallelize over N * G. Only needs one omp session, but memory access
321 // per thread is non-contiguous.
322 //
323 // impl-2: parallelize over N * HxW. Memory access per thread is contiguous,
324 // but it requires an extra temp buffer of size {T, N, 2C}.
325 //
326 // Generally impl-2 has better performance when HxW is large enough, so that
327 // the data per thread {NHWC / T} is much larger than the temp buffer per thread {2NC}.
328 //
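// A concrete example of this heuristic (a sketch, assuming 16 threads):
// for {N, C, H, W} = {8, 64, 56, 56}, HxW = 3136 >= 1024 so impl-2 is taken, and
// data per thread is ~8*3136*64/16 ~= 100K elements versus a 2*8*64 = 1024 element
// buffer per thread; for {8, 64, 7, 7}, HxW = 49 < 1024 and impl-1 is taken.
//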
329 constexpr int64_t feature_map_threshold = 1024;
330 if (HxW < feature_map_threshold) {
331 // impl-1: parallelize over N * G.
332 //
333 // For each HxW plane, scale and bias are calculated only once.
334 Tensor buffer = at::empty({N * G, 2 * D}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
335 opmath_t* buffer_data = buffer.data_ptr<opmath_t>();
336
337 at::parallel_for(0, N * G, 1, [&](int64_t begin, int64_t end) {
338 int64_t n{0}, g{0};
339 data_index_init(begin, n, N, g, G);
340 for (const auto i : c10::irange(begin, end)) {
341 // step-1: for each n and g, collect sum of x and x2
342 //
343 // Note that using vec::map_reduce_all here would be simpler to write,
344 // but it is slower since the horizontal reduction from vec to scalar is slow.
345 // So it is better to accumulate in a vec across the whole HxW plane
346 // and do a horizontal add just once for each {n, g}.
347 //
348 auto [mean_val, rstd_val] = ColumnwiseMoments(
349 X_data + n * HxW * C + g * D,
350 HxW,
351 C,
352 D);
353
354 mean_val *= s;
355 rstd_val = std::max(rstd_val * s - mean_val * mean_val, opmath_t(0));
356 rstd_val = opmath_t(1) / std::sqrt(rstd_val + eps);
357 mean_data[i] = mean_val;
358 rstd_data[i] = rstd_val;
359
360 // step-2: calculate scale and bias
361 opmath_t* scale_ptr = buffer_data + i * 2 * D;
362 opmath_t* bias_ptr = scale_ptr + D;
363 for (const auto d : c10::irange(D)) {
364 const int64_t c = g * D + d;
365 scale_ptr[d] = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
366 bias_ptr[d] = -scale_ptr[d] * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
367 }
368
369 // step-3: apply scale and bias
370 for (const auto m : c10::irange(HxW)) {
371 const T* X_ptr = X_data + n * HxW * C + m * C + g * D;
372 T* Y_ptr = Y_data + n * HxW * C + m * C + g * D;
373 ApplyScaleBias<T, opmath_t>(Y_ptr, X_ptr, scale_ptr, bias_ptr, D);
374 }
375
376 data_index_step(n, N, g, G);
377 }
378 });
379 } else {
380 // impl-2: parallelize over N * HxW.
381 //
382 // temp buffer holding x and x2
383 int num_threads = at::get_num_threads();
384 Tensor buffer = at::empty({num_threads, N, 2 * C},
385 X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value)).zero_();
386 opmath_t* buffer_data = buffer.data_ptr<opmath_t>();
387 Tensor tmp_buffer = at::empty({N, 2 * G},
388 X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
389 opmath_t* tmp_buffer_data = tmp_buffer.data_ptr<opmath_t>();
390 // step-1: accumulate on dimension of C
391 //
392 // In order to improve multi-core performance when N=1,
393 // we parallelize over all the outer dimensions N and HxW,
394 // leaving the innermost dimension C for vectorization.
395 //
396 // Note that parallelizing over {N, HxW, G} is not feasible for some common configs,
397 // e.g. say the input shape is {1, 32, h, w} and G = 8:
398 // this gives D = 4, which cannot fill the full SIMD width.
399 //
400 // To avoid thread conflicts, we make use of a temp buffer of {T, N, 2C}:
401 // first, reduce from {N, HxW, C} to {T, N, 2C}.
402 //
403 at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
404 int tid = at::get_thread_num();
405 opmath_t* buffer_ptr = buffer_data + tid * N * 2 * C;
406
407 int64_t n{0}, m{0};
408 data_index_init(begin, n, N, m, HxW);
409 for (const auto i : c10::irange(begin, end)) {
410 opmath_t* mean_ptr = buffer_ptr + n * 2 * C;
411 opmath_t* rstd_ptr = mean_ptr + C;
412 const T* X_ptr = X_data + i * C;
413 CalcMeanVar<T, opmath_t>(X_ptr, mean_ptr, rstd_ptr, C);
414 data_index_step(n, N, m, HxW);
415 }
416 });
417
418 // step-2: compute mean and rstd
419 for (const auto n : c10::irange(N)) {
420 for (const auto g : c10::irange(G)) {
421 opmath_t mean_val{0}, rstd_val{0};
422 for (const auto d : c10::irange(D)) {
423 for (const auto t : c10::irange(num_threads)) {
424 opmath_t* buffer_ptr = buffer_data + t * N * 2 * C + n * 2 * C;
425 mean_val += buffer_ptr[g * D + d];
426 rstd_val += buffer_ptr[g * D + d + C];
427 }
428 }
429 mean_val *= s;
430 rstd_val = std::max(rstd_val * s - mean_val * mean_val, opmath_t(0));
431 rstd_val = opmath_t(1) / std::sqrt(rstd_val + eps);
432 tmp_buffer_data[n * 2 * G + 2 * g] = mean_val;
433 tmp_buffer_data[n * 2 * G + 2 * g + 1] = rstd_val;
434 }
435 }
436
437 // step-3: compute scale and bias
438 //
439 // mean/rstd have shape {N, G}, gamma/beta have shape {G, D},
440 // and scale/bias have shape {N, C} so that we can directly vectorize along
441 // dimension C in the final step.
442 //
443 // We could fuse steps 3 and 4 into a single pass, but this way is better:
444 // a. D might be too small for vectorization;
445 // b. it avoids duplicate calculation of scale/bias, since each HxW plane shares the same scale/bias.
446 //
447 for (const auto n : c10::irange(N)) {
448 for (const auto g : c10::irange(G)) {
449 opmath_t* scale_ptr = buffer_data + n * 2 * C;
450 opmath_t* bias_ptr = scale_ptr + C;
451 opmath_t mean_val = tmp_buffer_data[n * 2 * G + 2 * g];
452 opmath_t rstd_val = tmp_buffer_data[n * 2 * G + 2 * g + 1];
453 mean_data[n * G + g] = mean_val;
454 rstd_data[n * G + g] = rstd_val;
455
456 for (const auto d : c10::irange(D)) {
457 const int64_t c = g * D + d;
458 scale_ptr[c] = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
459 bias_ptr[c] = -scale_ptr[c] * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
460 }
461 }
462 }
463
464 // step-4: apply scale and bias
465 //
466 // Parallelize over all the outer dimensions N and HxW
467 // and vectorize along C.
468 //
469 at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
470 int64_t n{0}, m{0};
471 data_index_init(begin, n, N, m, HxW);
472 for (const auto i : c10::irange(begin, end)) {
473 const T* X_ptr = X_data + i * C;
474 T* Y_ptr = Y_data + i * C;
475 opmath_t* scale_ptr = buffer_data + n * 2 * C;
476 opmath_t* bias_ptr = scale_ptr + C;
477 ApplyScaleBias<T, opmath_t>(Y_ptr, X_ptr, scale_ptr, bias_ptr, C);
478 data_index_step(n, N, m, HxW);
479 }
480 });
481 }
482 }
483
484 void GroupNormKernelImpl(
485 const Tensor& X,
486 const Tensor& gamma,
487 const Tensor& beta,
488 int64_t N,
489 int64_t C,
490 int64_t HxW,
491 int64_t group,
492 double eps,
493 Tensor& Y,
494 Tensor& mean,
495 Tensor& rstd) {
496 const bool mixed_type = is_mixed_type(X, gamma, beta);
497 switch (X.suggest_memory_format()) {
498 case at::MemoryFormat::Contiguous: {
499 AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, X.scalar_type(), "GroupNormKernelImpl", [&]() {
500 using param_t = at::opmath_type<scalar_t>;
501 if (mixed_type) {
502 GroupNormKernelImplInternal<scalar_t, param_t>(
503 X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
504 } else {
505 GroupNormKernelImplInternal<scalar_t, scalar_t>(
506 X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
507 }
508 });
509 break;
510 }
511 case at::MemoryFormat::ChannelsLast:
512 case at::MemoryFormat::ChannelsLast3d: {
513 AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, X.scalar_type(), "GroupNormKernelImpl", [&]() {
514 using param_t = at::opmath_type<scalar_t>;
515 if (mixed_type) {
516 GroupNormKernelImplChannelsLastInternal<scalar_t, param_t>(
517 X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
518 } else {
519 GroupNormKernelImplChannelsLastInternal<scalar_t, scalar_t>(
520 X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
521 }
522 });
523 break;
524 }
525 default:
526 TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, ChannelsLast3d, Contiguous");
527 }
528 }
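
// A minimal usage sketch of how this kernel is reached through the public op;
// the shapes, group count and eps here are made-up illustration values:
//
//   #include <ATen/ATen.h>
//
//   at::Tensor x = at::randn({8, 64, 56, 56});              // contiguous path
//   at::Tensor w = at::ones({64}), b = at::zeros({64});
//   at::Tensor y = at::group_norm(x, /*num_groups=*/8, w, b, /*eps=*/1e-5);
//
//   at::Tensor x_cl = x.contiguous(at::MemoryFormat::ChannelsLast);
//   at::Tensor y_cl = at::group_norm(x_cl, 8, w, b, 1e-5);  // channels-last path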
529
530
531 template <typename T, typename opmath_t>
532 typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
533 ComputeInternalGradients(
534 int64_t N,
535 int64_t C,
536 int64_t HxW,
537 const T* dY,
538 const T* X,
539 opmath_t* ds,
540 opmath_t* db) {
541 using Vec = at::vec::Vectorized<opmath_t>;
542 at::parallel_for(0, N * C, 1, [=](int64_t start, int64_t end) {
543 for (const auto i : c10::irange(start, end)) {
544 const T* dY_ptr = dY + i * HxW;
545 const T* X_ptr = X + i * HxW;
546 ds[i] = at::vec::map2_reduce_all<T>(
547 [](Vec x, Vec y) { return x * y; },
548 [](Vec x, Vec y) { return x + y; },
549 dY_ptr,
550 X_ptr,
551 HxW);
552 db[i] = at::vec::reduce_all<T>(
553 [](Vec& x, Vec& y) { return x + y; }, dY_ptr, HxW);
554 }
555 });
556 }
557
558 template <typename T, typename opmath_t>
559 typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
560 ComputeInternalGradients(
561 int64_t N,
562 int64_t C,
563 int64_t HxW,
564 const T* dY,
565 const T* X,
566 opmath_t* ds,
567 opmath_t* db) {
568 using Vec = vec::Vectorized<T>;
569 using fVec = vec::Vectorized<opmath_t>;
570 at::parallel_for(0, N * C, 1, [=](int64_t start, int64_t end) {
571 constexpr int64_t K = Vec::size();
572 const int64_t inner_size = HxW / K * K;
573 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
574 std::array<opmath_t, K / 2> ds_arr;
575 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
576 std::array<opmath_t, K / 2> db_arr;
577 for (const auto i : c10::irange(start, end)) {
578 const T* dY_ptr = dY + i * HxW;
579 const T* X_ptr = X + i * HxW;
580 fVec ds_vec(0);
581 fVec db_vec(0);
582 for (int64_t j = 0; j < inner_size; j += K) {
583 const Vec dy_bvec = Vec::loadu(dY_ptr + j);
584 const Vec x_bvec = Vec::loadu(X_ptr + j);
585 auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
586 auto [dy_fvec0, dy_fvec1] = convert_to_float<T>(dy_bvec);
587 ds_vec = ds_vec + dy_fvec0 * x_fvec0;
588 ds_vec = ds_vec + dy_fvec1 * x_fvec1;
589 db_vec = db_vec + dy_fvec0 + dy_fvec1;
590 }
591 ds_vec.store(ds_arr.data());
592 db_vec.store(db_arr.data());
593 opmath_t ds_val = std::accumulate(ds_arr.cbegin(), ds_arr.cend(), opmath_t(0));
594 opmath_t db_val = std::accumulate(db_arr.cbegin(), db_arr.cend(), opmath_t(0));
595 for (const auto j : c10::irange(inner_size, HxW)) {
596 ds_val += opmath_t(dY_ptr[j]) * opmath_t(X_ptr[j]);
597 db_val += opmath_t(dY_ptr[j]);
598 }
599 ds[i] = ds_val;
600 db[i] = db_val;
601 }
602 });
603 }
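
// For the contiguous backward path, the internal gradients computed above are,
// per (n, c) row of the {N, C} buffers:
//
//   ds[n][c] = sum over HxW of dY * X
//   db[n][c] = sum over HxW of dY
//
// They are later contracted with gamma over D to form the per-group coefficients.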
604
605 template <typename PT, typename opmath_t>
606 inline typename std::enable_if<std::is_same<PT, opmath_t>::value, void>::type
607 CalcDsDb(
608 const opmath_t* ds_ptr,
609 const opmath_t* db_ptr,
610 const PT* gamma_ptr,
611 const int64_t d,
612 const int64_t K,
613 void* ds_arr,
614 void* db_arr) {
615 vec::Vectorized<opmath_t> ds_vec(0);
616 vec::Vectorized<opmath_t> db_vec(0);
617 for (int64_t j = 0; j < d; j += K) {
618 const vec::Vectorized<PT> gamma_vec = (gamma_ptr == nullptr)
619 ? vec::Vectorized<PT>(1)
620 : vec::Vectorized<PT>::loadu(gamma_ptr + j);
621 ds_vec = ds_vec + vec::Vectorized<PT>::loadu(ds_ptr + j) * gamma_vec;
622 db_vec = db_vec + vec::Vectorized<PT>::loadu(db_ptr + j) * gamma_vec;
623 }
624 ds_vec.store(ds_arr);
625 db_vec.store(db_arr);
626 }
627
628 template <typename PT, typename opmath_t>
629 inline typename std::enable_if<!std::is_same<PT, opmath_t>::value, void>::type
630 CalcDsDb(
631 const opmath_t* ds_ptr,
632 const opmath_t* db_ptr,
633 const PT* gamma_ptr,
634 const int64_t d,
635 const int64_t K,
636 void* ds_arr,
637 void* db_arr) {
638 using fVec = at::vec::Vectorized<opmath_t>;
639 using Vec = at::vec::Vectorized<PT>;
640 fVec ds_acc(0);
641 fVec db_acc(0);
642 for (int64_t j = 0; j < d; j += K) {
643 const Vec gamma_vec = (gamma_ptr == nullptr) ? Vec(1) : Vec::loadu(gamma_ptr + j);
644 auto [gamma_vec0, gamma_vec1] = convert_to_float<PT>(gamma_vec);
645 ds_acc += fVec::loadu(ds_ptr + j) * gamma_vec0;
646 ds_acc += fVec::loadu(ds_ptr + j + fVec::size()) * gamma_vec1;
647 db_acc += fVec::loadu(db_ptr + j) * gamma_vec0;
648 db_acc += fVec::loadu(db_ptr + j + fVec::size()) * gamma_vec1;
649 }
650 ds_acc.store(ds_arr);
651 db_acc.store(db_arr);
652 }
653
654 template <typename T, typename PT, typename opmath_t>
655 void GroupNormInputBackward(
656 int64_t N,
657 int64_t C,
658 int64_t HxW,
659 int64_t group,
660 const T* dY,
661 const T* X,
662 const PT* mean,
663 const PT* rstd,
664 const PT* gamma,
665 const opmath_t* ds,
666 const opmath_t* db,
667 T* dX) {
668 const int64_t G = group;
669 const int64_t D = C / G;
670 const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
671 const bool gamma_null = (gamma == nullptr);
672 at::parallel_for(0, N * G, 1, [=](int64_t start, int64_t end) {
673 constexpr int64_t K = vec::Vectorized<PT>::size();
674 const int64_t d = D / K * K;
675 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
676 std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> ds_arr;
677 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
678 std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> db_arr;
679 for (const auto i : c10::irange(start, end)) {
680 const int64_t g = i % G;
681 const opmath_t* ds_ptr = ds + i * D;
682 const opmath_t* db_ptr = db + i * D;
683 const PT* gamma_ptr = gamma_null ? nullptr : (gamma + g * D);
684 CalcDsDb(ds_ptr, db_ptr, gamma_ptr, d, K, ds_arr.data(), db_arr.data());
685 opmath_t ds_val = std::accumulate(ds_arr.cbegin(), ds_arr.cend(), opmath_t(0));
686 opmath_t db_val = std::accumulate(db_arr.cbegin(), db_arr.cend(), opmath_t(0));
687 for (const auto j : c10::irange(d, D)) {
688 const opmath_t gamma_v = gamma_null ? opmath_t(1) : opmath_t(gamma[g * D + j]);
689 ds_val += ds_ptr[j] * gamma_v;
690 db_val += db_ptr[j] * gamma_v;
691 }
692 const opmath_t c2 =
693 (db_val * opmath_t(mean[i]) - ds_val) * opmath_t(rstd[i]) * opmath_t(rstd[i]) * opmath_t(rstd[i]) * s;
694 const opmath_t c3 = -c2 * opmath_t(mean[i]) - db_val * opmath_t(rstd[i]) * s;
695
696 for (const auto j : c10::irange(D)) {
697 const int64_t c = g * D + j;
698 const T* dY_ptr = dY + (i * D + j) * HxW;
699 const T* X_ptr = X + (i * D + j) * HxW;
700 T* dX_ptr = dX + (i * D + j) * HxW;
701 const opmath_t c1 = opmath_t(rstd[i]) * (gamma_null ? opmath_t(1) : opmath_t(gamma[c]));
702 for (const auto k : c10::irange(HxW)) {
703 dX_ptr[k] = c1 * opmath_t(dY_ptr[k]) + c2 * opmath_t(X_ptr[k]) + c3;
704 }
705 }
706 }
707 });
708 }
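
// The closed form applied above, written out: with s = 1 / (D * HxW) and
// ds_val = sum_j ds[j] * gamma[g*D+j], db_val = sum_j db[j] * gamma[g*D+j] for the
// group, the input gradient is dX = c1 * dY + c2 * X + c3, where
//
//   c1 = rstd * gamma[c]
//   c2 = (db_val * mean - ds_val) * rstd^3 * s
//   c3 = -c2 * mean - db_val * rstd * s
//
// i.e. the standard GroupNorm backward with the per-group reductions hoisted out
// of the inner HxW loop.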
709
710 template <typename PT, typename opmath_t>
711 typename std::enable_if<std::is_same<PT, opmath_t>::value, void>::type
712 GammaBackward(
713 int64_t N,
714 int64_t C,
715 int64_t group,
716 const PT* mean,
717 const PT* rstd,
718 const opmath_t* ds,
719 const opmath_t* db,
720 PT* dgamma) {
721 const int64_t G = group;
722 const int64_t D = C / G;
723 constexpr int64_t K = at::vec::Vectorized<PT>::size();
724 using Vec = at::vec::Vectorized<PT>;
725 const int64_t inner_size = D / K * K;
726 for (const auto g : c10::irange(G)) {
727 int64_t i = 0;
728 for (; i < inner_size; i += K) {
729 Vec acc_vec{0};
730 for (const auto n : c10::irange(N)) {
731 const PT* ds_ptr = ds + n * C + g * D + i;
732 const PT* db_ptr = db + n * C + g * D + i;
733 auto ds_vec = Vec::loadu(ds_ptr);
734 auto db_vec = Vec::loadu(db_ptr);
735 auto mean_vec = Vec(mean[n * G + g]);
736 auto rstd_vec = Vec(rstd[n * G + g]);
737 acc_vec += (ds_vec - db_vec * mean_vec) * rstd_vec;
738 }
739 acc_vec.store(dgamma + g * D + i);
740 }
741 if (D - i > 0) {
742 Vec acc_vec{0};
743 for (const auto n : c10::irange(N)) {
744 const PT* ds_ptr = ds + n * C + g * D + i;
745 const PT* db_ptr = db + n * C + g * D + i;
746 auto ds_vec = Vec::loadu(ds_ptr, D - i);
747 auto db_vec = Vec::loadu(db_ptr, D - i);
748 auto mean_vec = Vec(mean[n * G + g]);
749 auto rstd_vec = Vec(rstd[n * G + g]);
750 acc_vec += (ds_vec - db_vec * mean_vec) * rstd_vec;
751 }
752 acc_vec.store(dgamma + g * D + i, D - i);
753 }
754 }
755 }
756
757 template <typename PT, typename opmath_t>
758 typename std::enable_if<!std::is_same<PT, opmath_t>::value, void>::type
759 GammaBackward(
760 int64_t N,
761 int64_t C,
762 int64_t group,
763 const PT* mean,
764 const PT* rstd,
765 const opmath_t* ds,
766 const opmath_t* db,
767 PT* dgamma) {
768 const int64_t G = group;
769 const int64_t D = C / G;
770 using Vec = at::vec::Vectorized<PT>;
771 using fVec = at::vec::Vectorized<opmath_t>;
772 constexpr int64_t K = Vec::size();
773 const int64_t inner_size = D / K * K;
774 for (const auto g : c10::irange(G)) {
775 int64_t i = 0;
776 for (; i < inner_size; i += K) {
777 fVec acc0_vec{0}, acc1_vec{0};
778 for (const auto n : c10::irange(N)) {
779 const opmath_t* ds_ptr = ds + n * C + g * D + i;
780 const opmath_t* db_ptr = db + n * C + g * D + i;
781 fVec ds_vec0, ds_vec1, db_vec0, db_vec1;
782 ds_vec0 = fVec::loadu(ds_ptr);
783 ds_vec1 = fVec::loadu(ds_ptr + fVec::size());
784 db_vec0 = fVec::loadu(db_ptr);
785 db_vec1 = fVec::loadu(db_ptr + fVec::size());
786 fVec mean_vec = fVec(opmath_t(mean[n * G + g]));
787 fVec rstd_vec = fVec(opmath_t(rstd[n * G + g]));
788 acc0_vec += (ds_vec0 - db_vec0 * mean_vec) * rstd_vec;
789 acc1_vec += (ds_vec1 - db_vec1 * mean_vec) * rstd_vec;
790 }
791 convert_from_float<PT>(acc0_vec, acc1_vec).store(dgamma + g * D + i);
792 }
793 if (D - i > 0) {
794 fVec acc0_vec{0}, acc1_vec{0};
795 for (const auto n : c10::irange(N)) {
796 const opmath_t* ds_ptr = ds + n * C + g * D + i;
797 const opmath_t* db_ptr = db + n * C + g * D + i;
798 fVec ds_vec0, ds_vec1, db_vec0, db_vec1;
799 ds_vec0 = fVec::loadu(
800 ds_ptr, (D - i) > fVec::size() ? fVec::size() : (D - i));
801 ds_vec1 = fVec::loadu(
802 ds_ptr + fVec::size(),
803 (D - i) > fVec::size() ? (D - i - fVec::size()) : 0);
804 db_vec0 = fVec::loadu(
805 db_ptr, (D - i) > fVec::size() ? fVec::size() : (D - i));
806 db_vec1 = fVec::loadu(
807 db_ptr + fVec::size(),
808 (D - i) > fVec::size() ? (D - i - fVec::size()) : 0);
809 fVec mean_vec = fVec(opmath_t(mean[n * G + g]));
810 fVec rstd_vec = fVec(opmath_t(rstd[n * G + g]));
811 acc0_vec += (ds_vec0 - db_vec0 * mean_vec) * rstd_vec;
812 acc1_vec += (ds_vec1 - db_vec1 * mean_vec) * rstd_vec;
813 }
814 convert_from_float<PT>(acc0_vec, acc1_vec).store(dgamma + g * D + i, D - i);
815 }
816 }
817 }
818
819 template <typename PT, typename opmath_t>
820 typename std::enable_if<std::is_same<PT, opmath_t>::value, void>::type
821 BetaBackward(int64_t N, int64_t C, const opmath_t* db, PT* dbeta) {
822 using Vec = at::vec::Vectorized<PT>;
823 constexpr int64_t K = Vec::size();
824 Vec acc_vec{0}, zero{0};
825 const int64_t inner_size = C / K * K;
826 int64_t i = 0;
827 for (; i < inner_size; i += K) {
828 for (const auto n : c10::irange(N)) {
829 acc_vec += Vec::loadu(db + n * C + i);
830 }
831 acc_vec.store(dbeta + i);
832 acc_vec = Vec::set(acc_vec, zero);
833 }
834 if (C - i > 0) {
835 for (const auto n : c10::irange(N)) {
836 acc_vec += Vec::loadu(db + n * C + i, C - i);
837 }
838 acc_vec.store(dbeta + i, C - i);
839 acc_vec = Vec::set(acc_vec, zero, C - i);
840 }
841 }
842
843 template <typename PT, typename opmath_t>
844 typename std::enable_if<!std::is_same<PT, opmath_t>::value, void>::type
845 BetaBackward(int64_t N, int64_t C, const opmath_t* db, PT* dbeta) {
846 using Vec = at::vec::Vectorized<PT>;
847 using fVec = at::vec::Vectorized<opmath_t>;
848 constexpr int64_t K = Vec::size();
849 fVec acc0_vec{0}, acc1_vec{0}, zero{0};
850 const int64_t inner_size = C / K * K;
851 int64_t i = 0;
852 for (; i < inner_size; i += K) {
853 for (const auto n : c10::irange(N)) {
854 fVec db_vec0, db_vec1;
855 db_vec0 = fVec::loadu(db + n * C + i);
856 db_vec1 = fVec::loadu(db + n * C + i + fVec::size());
857 acc0_vec += db_vec0;
858 acc1_vec += db_vec1;
859 }
860 convert_from_float<PT>(acc0_vec, acc1_vec).store(dbeta + i);
861 acc0_vec = fVec::set(acc0_vec, zero);
862 acc1_vec = fVec::set(acc1_vec, zero);
863 }
864 if (C - i > 0) {
865 for (const auto n : c10::irange(N)) {
866 fVec db_vec0, db_vec1;
867 db_vec0 = fVec::loadu(
868 db + n * C + i, (C - i) > fVec::size() ? fVec::size() : (C - i));
869 db_vec1 = fVec::loadu(
870 db + n * C + i + fVec::size(),
871 (C - i) > fVec::size() ? (C - i - fVec::size()) : 0);
872 acc0_vec += db_vec0;
873 acc1_vec += db_vec1;
874 }
875 convert_from_float<PT>(acc0_vec, acc1_vec).store(dbeta + i, C - i);
876 acc0_vec = fVec::set(acc0_vec, zero, C - i);
877 acc1_vec = fVec::set(acc1_vec, zero, C - i);
878 }
879 }
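
// The reductions implemented by GammaBackward / BetaBackward, written out:
//
//   dgamma[g*D + d] = sum over n of (ds[n][g*D+d] - db[n][g*D+d] * mean[n][g]) * rstd[n][g]
//   dbeta[c]        = sum over n of db[n][c]
//
// Both walk the channel dimension in SIMD-width chunks and handle the tail with
// partial (count-limited) loads and stores.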
880
881 template <typename T, typename PT>
882 void GroupNormBackwardKernelImplInternal(
883 const Tensor& dY,
884 const Tensor& X,
885 const Tensor& mean,
886 const Tensor& rstd,
887 const Tensor& gamma,
888 int64_t N,
889 int64_t C,
890 int64_t HxW,
891 int64_t group,
892 Tensor& dX,
893 Tensor& dgamma,
894 Tensor& dbeta) {
895 TORCH_CHECK(dY.numel() == N * C * HxW);
896 TORCH_CHECK(X.numel() == N * C * HxW);
897 TORCH_CHECK(mean.numel() == N * group);
898 TORCH_CHECK(rstd.numel() == N * group);
899 TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
900 const T* dY_data = dY.const_data_ptr<T>();
901 const T* X_data = X.const_data_ptr<T>();
902 const PT* mean_data = mean.const_data_ptr<PT>();
903 const PT* rstd_data = rstd.const_data_ptr<PT>();
904 const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
905 T* dX_data = dX.defined() ? dX.data_ptr<T>() : nullptr;
906 PT* dgamma_data = dgamma.defined() ? dgamma.data_ptr<PT>() : nullptr;
907 PT* dbeta_data = dbeta.defined() ? dbeta.data_ptr<PT>() : nullptr;
908 using opmath_t = at::opmath_type<T>;
909 Tensor ds = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
910 Tensor db = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
911 opmath_t* ds_data = ds.data_ptr<opmath_t>();
912 opmath_t* db_data = db.data_ptr<opmath_t>();
913 ComputeInternalGradients<T, opmath_t>(N, C, HxW, dY_data, X_data, ds_data, db_data);
914
915 if (dX_data != nullptr) {
916 GroupNormInputBackward<T, PT, opmath_t>(
917 N,
918 C,
919 HxW,
920 group,
921 dY_data,
922 X_data,
923 mean_data,
924 rstd_data,
925 gamma_data,
926 ds_data,
927 db_data,
928 dX_data);
929 }
930 if (dgamma_data != nullptr) {
931 GammaBackward(
932 N, C, group, mean_data, rstd_data, ds_data, db_data, dgamma_data);
933 }
934 if (dbeta_data != nullptr) {
935 BetaBackward(N, C, db_data, dbeta_data);
936 }
937 }
938
939 template <typename T, typename opmath_t>
940 inline typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
941 DsDbRowwiseMomentsChannelsLast(
942 const T* dY_ptr,
943 const T* X_ptr,
944 opmath_t* ds_ptr,
945 opmath_t* db_ptr,
946 int64_t C) {
947 using Vec = vec::Vectorized<T>;
948 constexpr int64_t K = vec::Vectorized<T>::size();
949 const int64_t inner_size = C / K * K;
950 int64_t d = 0;
951 for (; d < inner_size; d += K) {
952 Vec ds_dev = Vec::loadu(ds_ptr + d);
953 Vec db_vec = Vec::loadu(db_ptr + d);
954 Vec x_vec = Vec::loadu(X_ptr + d);
955 Vec dy_vec = Vec::loadu(dY_ptr + d);
956
957 ds_dev += x_vec * dy_vec;
958 db_vec += dy_vec;
959 ds_dev.store(ds_ptr + d);
960 db_vec.store(db_ptr + d);
961 }
962 if (C - d > 0) {
963 Vec ds_dev = Vec::loadu(ds_ptr + d, C - d);
964 Vec db_vec = Vec::loadu(db_ptr + d, C - d);
965 Vec x_vec = Vec::loadu(X_ptr + d, C - d);
966 Vec dy_vec = Vec::loadu(dY_ptr + d, C - d);
967 ds_dev += x_vec * dy_vec;
968 db_vec += dy_vec;
969 ds_dev.store(ds_ptr + d, C - d);
970 db_vec.store(db_ptr + d, C - d);
971 }
972 }
973
974 template <typename T, typename opmath_t>
975 inline typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
976 DsDbRowwiseMomentsChannelsLast(
977 const T* dY_ptr,
978 const T* X_ptr,
979 opmath_t* ds_ptr,
980 opmath_t* db_ptr,
981 int64_t C) {
982 using fVec = vec::Vectorized<opmath_t>;
983 using Vec = vec::Vectorized<T>;
984 int64_t d = 0;
985 for (; d < C - (C % Vec::size()); d += Vec::size()) {
986 fVec ds_dev0 = fVec::loadu(ds_ptr + d);
987 fVec ds_dev1 = fVec::loadu(ds_ptr + d + fVec::size());
988 fVec db_vec0 = fVec::loadu(db_ptr + d);
989 fVec db_vec1 = fVec::loadu(db_ptr + d + fVec::size());
990 Vec x_vec = Vec::loadu(X_ptr + d);
991 Vec dy_vec = Vec::loadu(dY_ptr + d);
992 auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
993 auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
994 ds_dev0 += x_vec0 * dy_vec0;
995 ds_dev1 += x_vec1 * dy_vec1;
996 db_vec0 += dy_vec0;
997 db_vec1 += dy_vec1;
998
999 ds_dev0.store(ds_ptr + d);
1000 ds_dev1.store(ds_ptr + d + fVec::size());
1001 db_vec0.store(db_ptr + d);
1002 db_vec1.store(db_ptr + d + fVec::size());
1003
1004 }
1005 if (C - d > 0) {
1006 fVec ds_dev0 = fVec::loadu(ds_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
1007 fVec ds_dev1 = fVec::loadu(ds_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
1008 fVec db_vec0 = fVec::loadu(db_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
1009 fVec db_vec1 = fVec::loadu(db_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
1010 Vec x_vec = Vec::loadu(X_ptr + d, C - d);
1011 Vec dy_vec = Vec::loadu(dY_ptr + d, C - d);
1012 auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1013 auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1014 ds_dev0 += x_vec0 * dy_vec0;
1015 ds_dev1 += x_vec1 * dy_vec1;
1016 db_vec0 += dy_vec0;
1017 db_vec1 += dy_vec1;
1018
1019 ds_dev0.store(ds_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
1020 ds_dev1.store(ds_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
1021 db_vec0.store(db_ptr + d, (C - d) > fVec::size() ? fVec::size() : (C - d));
1022 db_vec1.store(db_ptr + d + fVec::size(), (C - d) > fVec::size() ? (C - d - fVec::size()) : 0);
1023 }
1024 }
1025
1026 template <typename T>
1027 inline typename std::enable_if<std::is_same<T, at::opmath_type<T>>::value,
1028 std::tuple<
1029 vec::Vectorized<T>,
1030 vec::Vectorized<T>>>::type
1031 load_util(const T* data_ptr, int64_t n) {
1032 using Vec = vec::Vectorized<T>;
1033 auto vec0 = Vec::loadu(data_ptr, n > Vec::size() ? Vec::size() : n);
1034 auto vec1 = Vec::loadu(
1035 data_ptr + Vec::size(), n > Vec::size() ? (n - Vec::size()) : 0);
1036 return std::tuple<Vec, Vec>(vec0, vec1);
1037 }
1038
1039 template <typename T>
1040 inline typename std::enable_if<!std::is_same<T, at::opmath_type<T>>::value,
1041 std::tuple<
1042 vec::Vectorized<at::opmath_type<T>>,
1043 vec::Vectorized<at::opmath_type<T>>>
1044 >::type
1045 load_util(const T* data_ptr, int64_t n) {
1046 using Vec = vec::Vectorized<T>;
1047 auto vec = Vec::loadu(data_ptr, n);
1048 return convert_to_float<T>(vec);
1049 }
1050
1051 template <typename T, typename PT, typename opmath_t>
1052 inline typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
1053 ApplyInputGradientsChannelsLastColMov(
1054 const T* dY_data,
1055 const T* X_data,
1056 T* dX_data,
1057 const PT* rstd,
1058 const PT* gamma,
1059 opmath_t c2,
1060 opmath_t c3,
1061 int64_t HxW,
1062 int64_t C,
1063 int64_t D) {
1064 const bool gamma_null = (gamma == nullptr);
1065 int64_t d = 0;
1066 auto K = vec::Vectorized<T>::size();
1067 for (; d < D / K * K; d += K) {
1068 auto c1 = vec::Vectorized<T>(*rstd) *
1069 (gamma_null ? vec::Vectorized<T>(1)
1070 : vec::Vectorized<T>::loadu(gamma + d));
1071 for (const auto m : c10::irange(HxW)) {
1072 const T* X_ptr = X_data + m * C;
1073 const T* dY_ptr = dY_data + m * C;
1074 T* dX_ptr = dX_data + m * C;
1075 auto dy_vec = vec::Vectorized<T>::loadu(dY_ptr + d);
1076 auto x_vec = vec::Vectorized<T>::loadu(X_ptr + d);
1077 auto dx_vec = c1 * dy_vec +
1078 vec::Vectorized<T>(c2) * x_vec + vec::Vectorized<T>(c3);
1079 dx_vec.store(dX_ptr + d);
1080 }
1081 }
1082 if (D - d > 0) {
1083 auto c1 = vec::Vectorized<T>(*rstd) *
1084 (gamma_null ? vec::Vectorized<T>(1)
1085 : vec::Vectorized<T>::loadu(gamma + d, D - d));
1086 for (const auto m : c10::irange(HxW)) {
1087 const T* X_ptr = X_data + m * C;
1088 const T* dY_ptr = dY_data + m * C;
1089 T* dX_ptr = dX_data + m * C;
1090 auto dy_vec = vec::Vectorized<T>::loadu(dY_ptr + d, D - d);
1091 auto x_vec = vec::Vectorized<T>::loadu(X_ptr + d, D - d);
1092 auto dx_vec = c1 * dy_vec +
1093 vec::Vectorized<T>(c2) * x_vec + vec::Vectorized<T>(c3);
1094 dx_vec.store(dX_ptr + d, D - d);
1095 }
1096 }
1097 }
1098
1099 template <typename T, typename PT, typename opmath_t>
1100 inline typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
1101 ApplyInputGradientsChannelsLastColMov(
1102 const T* dY_data,
1103 const T* X_data,
1104 T* dX_data,
1105 const PT* rstd,
1106 const PT* gamma,
1107 opmath_t c2,
1108 opmath_t c3,
1109 int64_t HxW,
1110 int64_t C,
1111 int64_t D) {
1112 using Vec = vec::Vectorized<T>;
1113 using fVec = vec::Vectorized<opmath_t>;
1114 const bool gamma_null = (gamma == nullptr);
1115 auto K = Vec::size();
1116 int64_t d = 0;
1117 for (; d < D / K * K; d += K) {
1118 auto [c1_0, c1_1] = gamma_null ? std::tuple<fVec, fVec>(fVec(1), fVec(1))
1119 : load_util(gamma + d, K);
1120 c1_0 = c1_0 * fVec(opmath_t(*rstd));
1121 c1_1 = c1_1 * fVec(opmath_t(*rstd));
1122 for (const auto m : c10::irange(HxW)) {
1123 const T* X_ptr = X_data + m * C;
1124 const T* dY_ptr = dY_data + m * C;
1125 T* dX_ptr = dX_data + m * C;
1126
1127 Vec dy_vec = Vec::loadu(dY_ptr + d);
1128 Vec x_vec = Vec::loadu(X_ptr + d);
1129 auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1130 auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1131 fVec dx_vec0 = c1_0 * dy_vec0 + fVec(c2) * x_vec0 + fVec(c3);
1132 fVec dx_vec1 = c1_1 * dy_vec1 + fVec(c2) * x_vec1 + fVec(c3);
1133 convert_from_float<T>(dx_vec0, dx_vec1).store(dX_ptr + d);
1134 }
1135 }
1136 if (D - d > 0) {
1137 auto [c1_0, c1_1] = gamma_null ? std::tuple<fVec, fVec>(fVec(1), fVec(1))
1138 : load_util(gamma + d, D - d);
1139 c1_0 = c1_0 * fVec(opmath_t(*rstd));
1140 c1_1 = c1_1 * fVec(opmath_t(*rstd));
1141 for (const auto m : c10::irange(HxW)) {
1142 const T* X_ptr = X_data + m * C;
1143 const T* dY_ptr = dY_data + m * C;
1144 T* dX_ptr = dX_data + m * C;
1145 Vec dy_vec = Vec::loadu(dY_ptr + d, D - d);
1146 Vec x_vec = Vec::loadu(X_ptr + d, D - d);
1147 auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1148 auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1149 fVec dx_vec0 = c1_0 * dy_vec0 + fVec(c2) * x_vec0 + fVec(c3);
1150 fVec dx_vec1 = c1_1 * dy_vec1 + fVec(c2) * x_vec1 + fVec(c3);
1151 convert_from_float<T>(dx_vec0, dx_vec1).store(dX_ptr + d, D - d);
1152 }
1153 }
1154 }
1155
1156 template <typename T, typename PT, typename opmath_t>
1157 inline typename std::enable_if<std::is_same<T, opmath_t>::value, void>::type
1158 ApplyInputGradientsChannelsLastRowMov(
1159 const T* dY_data,
1160 const T* X_data,
1161 T* dX_data,
1162 const PT* rstd,
1163 const PT* gamma,
1164 opmath_t c2,
1165 opmath_t c3,
1166 int64_t HxW,
1167 int64_t C,
1168 int64_t D) {
1169 const bool gamma_null = (gamma == nullptr);
1170 int64_t d = 0;
1171 auto K = vec::Vectorized<T>::size();
1172 for (; d < D / K * K; d += K) {
1173 auto c1 = vec::Vectorized<T>(*rstd) *
1174 (gamma_null ? vec::Vectorized<T>(1) : vec::Vectorized<T>::loadu(gamma + d));
1175 auto dy_vec = vec::Vectorized<T>::loadu(dY_data + d);
1176 auto x_vec = vec::Vectorized<T>::loadu(X_data + d);
1177 auto dx_vec = c1 * dy_vec +
1178 vec::Vectorized<T>(c2) * x_vec + vec::Vectorized<T>(c3);
1179 dx_vec.store(dX_data + d);
1180 }
1181 if (D - d > 0) {
1182 auto c1 = vec::Vectorized<T>(*rstd) *
1183 (gamma_null ? vec::Vectorized<T>(1) : vec::Vectorized<T>::loadu(gamma + d, D - d));
1184 auto dy_vec = vec::Vectorized<T>::loadu(dY_data + d, D - d);
1185 auto x_vec = vec::Vectorized<T>::loadu(X_data + d, D - d);
1186 auto dx_vec = c1 * dy_vec +
1187 vec::Vectorized<T>(c2) * x_vec + vec::Vectorized<T>(c3);
1188 dx_vec.store(dX_data + d, D - d);
1189 }
1190 }
1191
1192 template <typename T, typename PT, typename opmath_t>
1193 inline typename std::enable_if<!std::is_same<T, opmath_t>::value, void>::type
1194 ApplyInputGradientsChannelsLastRowMov(
1195 const T* dY_data,
1196 const T* X_data,
1197 T* dX_data,
1198 const PT* rstd,
1199 const PT* gamma,
1200 opmath_t c2,
1201 opmath_t c3,
1202 int64_t HxW,
1203 int64_t C,
1204 int64_t D) {
1205 using Vec = vec::Vectorized<T>;
1206 using fVec = vec::Vectorized<opmath_t>;
1207 const bool gamma_null = (gamma == nullptr);
1208 auto K = Vec::size();
1209 int64_t d = 0;
1210 for (; d < D / K * K; d += K) {
1211 auto [c1_0, c1_1] = gamma_null ? std::tuple<fVec, fVec>(fVec(1), fVec(1))
1212 : load_util(gamma + d, K);
1213 c1_0 = c1_0 * fVec(opmath_t(*rstd));
1214 c1_1 = c1_1 * fVec(opmath_t(*rstd));
1215 Vec dy_vec = Vec::loadu(dY_data + d);
1216 Vec x_vec = Vec::loadu(X_data + d);
1217 auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1218 auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1219 fVec dx_vec0 = c1_0 * dy_vec0 + fVec(c2) * x_vec0 + fVec(c3);
1220 fVec dx_vec1 = c1_1 * dy_vec1 + fVec(c2) * x_vec1 + fVec(c3);
1221 convert_from_float<T>(dx_vec0, dx_vec1).store(dX_data + d);
1222 }
1223 if (D - d > 0) {
1224 auto [c1_0, c1_1] = gamma_null ? std::tuple<fVec, fVec>(fVec(1), fVec(1))
1225 : load_util(gamma + d, D - d);
1226 c1_0 = c1_0 * fVec(opmath_t(*rstd));
1227 c1_1 = c1_1 * fVec(opmath_t(*rstd));
1228 Vec dy_vec = Vec::loadu(dY_data + d, D - d);
1229 Vec x_vec = Vec::loadu(X_data + d, D - d);
1230 auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1231 auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1232 fVec dx_vec0 = c1_0 * dy_vec0 + fVec(c2) * x_vec0 + fVec(c3);
1233 fVec dx_vec1 = c1_1 * dy_vec1 + fVec(c2) * x_vec1 + fVec(c3);
1234 convert_from_float<T>(dx_vec0, dx_vec1).store(dX_data + d, D - d);
1235 }
1236 }
1237
1238 template <typename T, typename PT, typename opmath_t>
1239 inline typename std::
1240 enable_if<std::is_same<T, opmath_t>::value, std::tuple<opmath_t, opmath_t>>::type
1241 CalcInternalGradientsChannelsLast(
1242 const T* X_data,
1243 const T* dY_data,
1244 const PT* gamma_ptr,
1245 opmath_t* ds_ptr,
1246 opmath_t* db_ptr,
1247 int64_t HxW,
1248 int64_t C,
1249 int64_t D) {
1250 using Vec = vec::Vectorized<T>;
1251 const bool gamma_null = (gamma_ptr == nullptr);
1252 constexpr int64_t K = Vec::size();
1253 const int64_t inner_size = D / K * K;
1254 int64_t d = 0;
1255 opmath_t ds_gamma{0}, db_gamma{0};
1256 for (; d < inner_size; d += K) {
1257 Vec acc0_vec{0}, acc1_vec{0};
1258 for (const auto m : c10::irange(HxW)) {
1259 const T* X_ptr = X_data + m * C;
1260 const T* dY_ptr = dY_data + m * C;
1261 Vec x_vec = Vec::loadu(X_ptr + d);
1262 Vec dy_vec = Vec::loadu(dY_ptr + d);
1263 acc0_vec += x_vec * dy_vec;
1264 acc1_vec += dy_vec;
1265 }
1266 acc0_vec.store(ds_ptr + d);
1267 acc1_vec.store(db_ptr + d);
1268 ds_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
1269 acc0_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d)));
1270 db_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
1271 acc1_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d)));
1272 }
1273 if (D - d > 0) {
1274 Vec acc0_vec{0}, acc1_vec{0};
1275 for (const auto m : c10::irange(HxW)) {
1276 const T* X_ptr = X_data + m * C;
1277 const T* dY_ptr = dY_data + m * C;
1278 Vec x_vec = Vec::loadu(X_ptr + d, D - d);
1279 Vec dy_vec = Vec::loadu(dY_ptr + d, D - d);
1280 acc0_vec += x_vec * dy_vec;
1281 acc1_vec += dy_vec;
1282 }
1283 acc0_vec.store(ds_ptr + d, D - d);
1284 acc1_vec.store(db_ptr + d, D - d);
1285 ds_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
1286 acc0_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d, D - d)));
1287 db_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
1288 acc1_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d, D - d)));
1289 }
1290 return std::tuple<opmath_t, opmath_t>(ds_gamma, db_gamma);
1291 }
1292
1293 template <typename T, typename PT, typename opmath_t>
1294 inline typename std::
1295 enable_if<!std::is_same<T, opmath_t>::value, std::tuple<opmath_t, opmath_t>>::type
1296 CalcInternalGradientsChannelsLast(
1297 const T* X_data,
1298 const T* dY_data,
1299 const PT* gamma_ptr,
1300 opmath_t* ds_ptr,
1301 opmath_t* db_ptr,
1302 int64_t HxW,
1303 int64_t C,
1304 int64_t D) {
1305 using Vec = vec::Vectorized<T>;
1306 using fVec = vec::Vectorized<opmath_t>;
1307 const bool gamma_null = (gamma_ptr == nullptr);
1308 constexpr int64_t K = Vec::size();
1309 const int64_t inner_size = D / K * K;
1310 float ds_gamma{0}, db_gamma{0};
1311 int64_t d = 0;
1312 for (; d < inner_size; d += K) {
1313 fVec acc0_vec0{0}, acc0_vec1{0}, acc1_vec0{0}, acc1_vec1{0};
1314 for (const auto m : c10::irange(HxW)) {
1315 const T* X_ptr = X_data + m * C;
1316 const T* dY_ptr = dY_data + m * C;
1317 Vec x_vec = Vec::loadu(X_ptr + d);
1318 Vec dy_vec = Vec::loadu(dY_ptr + d);
1319 auto [x_vec0, x_vec1] = convert_to_float<T>(x_vec);
1320 auto [dy_vec0, dy_vec1] = convert_to_float<T>(dy_vec);
1321 acc0_vec0 += x_vec0 * dy_vec0;
1322 acc0_vec1 += x_vec1 * dy_vec1;
1323 acc1_vec0 += dy_vec0;
1324 acc1_vec1 += dy_vec1;
1325 }
1326 acc0_vec0.store(ds_ptr + d);
1327 acc0_vec1.store(ds_ptr + d + fVec::size());
1328 acc1_vec0.store(db_ptr + d);
1329 acc1_vec1.store(db_ptr + d + fVec::size());
1330 auto [gamma_vec0, gamma_vec1] = gamma_null ?
1331 std::tuple<fVec, fVec>(fVec(1), fVec(1)) : load_util(gamma_ptr + d, K);
1332 ds_gamma += vec::vec_reduce_all(
1333 [](fVec& x, fVec& y) { return x + y; }, acc0_vec0 * gamma_vec0);
1334 ds_gamma += vec::vec_reduce_all(
1335 [](fVec& x, fVec& y) { return x + y; }, acc0_vec1 * gamma_vec1);
1336 db_gamma += vec::vec_reduce_all(
1337 [](fVec& x, fVec& y) { return x + y; }, acc1_vec0 * gamma_vec0);
1338 db_gamma += vec::vec_reduce_all(
1339 [](fVec& x, fVec& y) { return x + y; }, acc1_vec1 * gamma_vec1);
1340 }
1341 for (; d < D; d++) {
1342 opmath_t acc0{0}, acc1{0};
1343 for (const auto m : c10::irange(HxW)) {
1344 const T* X_ptr = X_data + m * C;
1345 const T* dY_ptr = dY_data + m * C;
1346 acc0 += opmath_t(X_ptr[d]) * opmath_t(dY_ptr[d]);
1347 acc1 += opmath_t(dY_ptr[d]);
1348 }
1349 ds_ptr[d] = acc0;
1350 db_ptr[d] = acc1;
1351 opmath_t gamma_val = gamma_null ? opmath_t(1) : opmath_t(gamma_ptr[d]);
1352 ds_gamma += acc0 * gamma_val;
1353 db_gamma += acc1 * gamma_val;
1354 }
1355
1356 return std::tuple<opmath_t, opmath_t>(ds_gamma, db_gamma);
1357 }
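
// Besides the per-channel ds/db (used later for dgamma/dbeta), both
// CalcInternalGradientsChannelsLast overloads also return the gamma-weighted
// per-group sums that feed the dX coefficients:
//
//   ds_gamma = sum over d of ds[d] * gamma[g*D + d]
//   db_gamma = sum over d of db[d] * gamma[g*D + d]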
1358
1359 template <typename T, typename PT>
1360 void GroupNormBackwardKernelImplChannelsLastInternal(
1361 const Tensor& dY,
1362 const Tensor& X,
1363 const Tensor& mean,
1364 const Tensor& rstd,
1365 const Tensor& gamma,
1366 int64_t N,
1367 int64_t C,
1368 int64_t HxW,
1369 int64_t group,
1370 Tensor& dX,
1371 Tensor& dgamma,
1372 Tensor& dbeta) {
1373 TORCH_CHECK(dY.numel() == N * C * HxW);
1374 TORCH_CHECK(X.numel() == N * C * HxW);
1375 TORCH_CHECK(mean.numel() == N * group);
1376 TORCH_CHECK(rstd.numel() == N * group);
1377 TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
1378 int64_t D = C / group;
1379 int64_t G = group;
1380 const T* dY_data = dY.const_data_ptr<T>();
1381 const T* X_data = X.const_data_ptr<T>();
1382 const PT* mean_data = mean.const_data_ptr<PT>();
1383 const PT* rstd_data = rstd.const_data_ptr<PT>();
1384 const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
1385 T* dX_data = dX.defined() ? dX.data_ptr<T>() : nullptr;
1386 PT* dgamma_data = dgamma.defined() ? dgamma.data_ptr<PT>() : nullptr;
1387 PT* dbeta_data = dbeta.defined() ? dbeta.data_ptr<PT>() : nullptr;
1388 const bool gamma_null = (gamma_data == nullptr);
1389 using opmath_t = at::opmath_type<T>;
1390 Tensor ds = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
1391 Tensor db = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
1392 opmath_t* ds_data = ds.data_ptr<opmath_t>();
1393 opmath_t* db_data = db.data_ptr<opmath_t>();
1394 const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
1395
1396 // Similar to the channels last forward, the channels last backward also has 2 impls.
1397 // impl-1: parallelize over N * G. Only needs one omp session for the input gradients,
1398 // but memory access per thread is non-contiguous.
1399 //
1400 // impl-2: parallelize over N * HxW. Memory access per thread is contiguous,
1401 // but it requires an extra temp buffer of size {T, N, 2C}.
1402
1403 // Generally impl-2 has better performance when HxW is large enough, so that
1404 // the data per thread {NHWC / T} is much larger than the temp buffer per thread {2NC}.
1405 constexpr int64_t feature_map_threshold = 2048;
1406 if (HxW < feature_map_threshold) {
1407 // impl-1: parallel on N * G.
1408 at::parallel_for(0, N * G, 1, [=](int64_t begin, int64_t end) {
1409 int64_t n{0}, g{0};
1410 data_index_init(begin, n, N, g, G);
1411 for (const auto i : c10::irange(begin, end)) {
1412 // Step 1. Compute internal gradients.
1413 opmath_t* ds_ptr = ds_data + i * D;
1414 opmath_t* db_ptr = db_data + i * D;
1415 const T* X_ptr = X_data + n * HxW * C + g * D;
1416 const T* dY_ptr = dY_data + n * HxW * C + g * D;
1417 const PT* gamma_ptr = gamma_null ? gamma_data : (gamma_data + g * D);
1418 auto [ds_gamma, db_gamma] = CalcInternalGradientsChannelsLast<T, PT, opmath_t>(
1419 X_ptr,
1420 dY_ptr,
1421 gamma_ptr,
1422 ds_ptr,
1423 db_ptr,
1424 HxW,
1425 C,
1426 D);
1427
1428 // Step 2. Compute dX.
1429 T* dX_ptr = dX_data + n * HxW * C + g * D;
1430 const PT* rstd_ptr = rstd_data + i;
1431 const opmath_t c2 = (db_gamma * opmath_t(mean_data[i]) - ds_gamma) *
1432 opmath_t(rstd_data[i]) * opmath_t(rstd_data[i]) * opmath_t(rstd_data[i]) * s;
1433 const opmath_t c3 = -c2 * opmath_t(mean_data[i]) - db_gamma * opmath_t(rstd_data[i]) * s;
1434 ApplyInputGradientsChannelsLastColMov<T, PT, opmath_t>(dY_ptr, X_ptr, dX_ptr, rstd_ptr, gamma_ptr, c2, c3, HxW, C, D);
1435 data_index_step(n, N, g, G);
1436 }
1437 });
1438
1439 } else {
1440 // impl-2: parallel on N * HxW.
1441 int num_threads = at::get_num_threads();
1442 Tensor buffer = at::empty({num_threads, N, 2 * C},
1443 X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value)).zero_();
1444 opmath_t* buffer_data = buffer.data_ptr<opmath_t>();
1445
1446 Tensor tmp_buffer = at::empty({N, 2 * G},
1447 X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
1448 opmath_t* tmp_buffer_data = tmp_buffer.data_ptr<opmath_t>();
1449
1450 // Step 1. Each thread computes its own internal gradients into the buffer.
1451 at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
1452 int tid = at::get_thread_num();
1453 opmath_t* buffer_ptr = buffer_data + tid * N * 2 * C;
1454 int64_t n{0}, m{0};
1455 data_index_init(begin, n, N, m, HxW);
1456 for (const auto i : c10::irange(begin, end)) {
1457 opmath_t* ds_ptr = buffer_ptr + n * 2 * C;
1458 opmath_t* db_ptr = ds_ptr + C;
1459 const T* X_ptr = X_data + i * C;
1460 const T* dY_ptr = dY_data + i * C;
1461
1462 DsDbRowwiseMomentsChannelsLast<T, opmath_t>(dY_ptr, X_ptr, ds_ptr, db_ptr, C);
1463 data_index_step(n, N, m, HxW);
1464 }
1465 });
1466
1467 // Step 2. Collect the internal gradients from each thread and
1468 // reduce them into the final internal gradients in ds, db, and tmp_buffer.
1469 for (const auto n : c10::irange(N)) {
1470 for (const auto g : c10::irange(G)) {
1471 opmath_t ds_gamma{0}, db_gamma{0};
1472 for (const auto d : c10::irange(D)) {
1473 opmath_t ds_val{0}, db_val{0};
1474 for (const auto t : c10::irange(num_threads)) {
1475 opmath_t* buffer_ptr = buffer_data + t * N * 2 * C + n * 2 * C;
1476 opmath_t gamma_val = gamma_null ? opmath_t(1) : opmath_t(gamma_data[g * D + d]);
1477 ds_gamma += buffer_ptr[g * D + d] * gamma_val;
1478 db_gamma += buffer_ptr[g * D + d + C] * gamma_val;
1479 ds_val += buffer_ptr[g * D + d];
1480 db_val += buffer_ptr[g * D + d + C];
1481
1482 }
1483 ds_data[n * C + g * D + d] = ds_val;
1484 db_data[n * C + g * D + d] = db_val;
1485 }
1486 tmp_buffer_data[n * 2 * G + 2 * g] = ds_gamma;
1487 tmp_buffer_data[n * 2 * G + 2 * g + 1] = db_gamma;
1488 }
1489 }
1490
1491 // Step 3. Compute dx.
1492 if (dX_data != nullptr) {
1493 at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
1494 int64_t n{0}, m{0};
1495 data_index_init(begin, n, N, m, HxW);
1496 for (const auto i : c10::irange(begin, end)) {
1497 for (const auto g : c10::irange(G)) {
1498 const T* X_ptr = X_data + i * C + g * D;
1499 const T* dY_ptr = dY_data + i * C + g * D;
1500 T* dX_ptr = dX_data + i * C + g * D;
1501 const PT* mean_ptr = mean_data + n * G + g;
1502 const PT* rstd_ptr = rstd_data + n * G + g;
1503 const PT* gamma_ptr = gamma_null ? gamma_data : (gamma_data + g * D);
1504 opmath_t ds_val = tmp_buffer_data[n * 2 * G + 2 * g];
1505 opmath_t db_val = tmp_buffer_data[n * 2 * G + 2 * g + 1];
1506
1507 const opmath_t c2 = (db_val * opmath_t(*mean_ptr) - ds_val) *
1508 opmath_t(*rstd_ptr) * opmath_t(*rstd_ptr)* opmath_t(*rstd_ptr) * s;
1509 const opmath_t c3 = -c2 * opmath_t(*mean_ptr) - db_val * opmath_t(*rstd_ptr) * s;
1510 ApplyInputGradientsChannelsLastRowMov<T, PT, opmath_t>(dY_ptr, X_ptr, dX_ptr, rstd_ptr, gamma_ptr, c2, c3, HxW, C, D);
1511 }
1512
1513 data_index_step(n, N, m, HxW);
1514 }
1515 });
1516 }
1517
1518 }
1519
1520 // Finally compute dgamma and dbeta.
1521 if (dgamma_data != nullptr) {
1522 GammaBackward(
1523 N, C, group, mean_data, rstd_data, ds_data, db_data, dgamma_data);
1524 }
1525 if (dbeta_data != nullptr) {
1526 BetaBackward(N, C, db_data, dbeta_data);
1527 }
1528 }
1529
1530 void GroupNormBackwardKernelImpl(
1531 const Tensor& dY,
1532 const Tensor& X,
1533 const Tensor& mean,
1534 const Tensor& rstd,
1535 const Tensor& gamma,
1536 int64_t N,
1537 int64_t C,
1538 int64_t HxW,
1539 int64_t group,
1540 Tensor& dX,
1541 Tensor& dgamma,
1542 Tensor& dbeta) {
1543 // In training, using AMP to enable a lower precision data type,
1544 // i.e., BFloat16 or Half, is recommended.
1545 // It keeps the module parameters in the opmath dtype (i.e. float)
1546 // while the input/output are in the lower precision data type.
1547 // Keeping the parameters in BFloat16 or Half may cause a high precision loss.
1548 const bool mixed_type = is_mixed_type(dY, mean);
1549 switch (X.suggest_memory_format()) {
1550 case at::MemoryFormat::Contiguous: {
1551 AT_DISPATCH_FLOATING_TYPES_AND2(
1552 ScalarType::BFloat16, ScalarType::Half, X.scalar_type(), "GroupNormBackwardKernelImpl", [&]() {
1553 using param_t = at::opmath_type<scalar_t>;
1554 if(mixed_type) {
1555 GroupNormBackwardKernelImplInternal<scalar_t, param_t>(
1556 dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
1557 } else {
1558 GroupNormBackwardKernelImplInternal<scalar_t, scalar_t>(
1559 dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
1560 }
1561 });
1562 break;
1563 }
1564 case at::MemoryFormat::ChannelsLast:
1565 case at::MemoryFormat::ChannelsLast3d: {
1566 AT_DISPATCH_FLOATING_TYPES_AND2(
1567 ScalarType::BFloat16, ScalarType::Half, X.scalar_type(), "GroupNormBackwardKernelImpl", [&]() {
1568 using param_t = at::opmath_type<scalar_t>;
1569 if(mixed_type) {
1570 GroupNormBackwardKernelImplChannelsLastInternal<scalar_t, param_t>(
1571 dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
1572 } else {
1573 GroupNormBackwardKernelImplChannelsLastInternal<scalar_t, scalar_t>(
1574 dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
1575 }
1576 });
1577 break;
1578 }
1579 default:
1580 TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, ChannelsLast3d, Contiguous");
1581 }
1582
1583 }
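
// A minimal autograd usage sketch that exercises the backward kernels above,
// assuming the libtorch C++ frontend; shapes and values are made-up illustrations:
//
//   #include <torch/torch.h>
//
//   auto x = torch::randn({8, 64, 28, 28}, torch::requires_grad());
//   auto w = torch::ones({64}, torch::requires_grad());
//   auto b = torch::zeros({64}, torch::requires_grad());
//   auto y = torch::group_norm(x, /*num_groups=*/8, w, b, /*eps=*/1e-5);
//   y.sum().backward();  // produces dX, dgamma, dbeta via the kernels above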
1584
1585 } // namespace
1586
1587 REGISTER_DISPATCH(GroupNormKernel, &GroupNormKernelImpl);
1588 REGISTER_DISPATCH(GroupNormBackwardKernel, &GroupNormBackwardKernelImpl);
1589
1590 } // namespace at::native
1591