#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/layer_norm.h>

#include <cmath>
#include <tuple>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/moments_utils.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace at::native {

namespace {

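// Forward kernel for full precision inputs (float/double). For each of the M
// rows, RowwiseMoments computes the mean and variance, rstd is derived as
// 1 / sqrt(var + eps), and the row is normalized and transformed by the
// optional gamma/beta parameters. Mean and rstd are optionally written out
// for the backward pass.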
template <typename T,
          typename std::enable_if_t<!is_reduced_floating_point_v<T>, int> = 0>
void LayerNormKernelImplInternal(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    T eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  using Vec = vec::Vectorized<T>;
  const T* X_data = X.const_data_ptr<T>();
  const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
  const T* beta_data = beta.defined() ? beta.const_data_ptr<T>() : nullptr;
  T* Y_data = Y->data_ptr<T>();
  T* mean_data = mean ? mean->data_ptr<T>() : nullptr;
  T* rstd_data = rstd ? rstd->data_ptr<T>() : nullptr;

  const bool gamma_null = gamma_data == nullptr;
  const bool beta_null = beta_data == nullptr;
  const bool mean_null = mean_data == nullptr;
  const bool rstd_null = rstd_data == nullptr;
  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
    for (const auto i : c10::irange(start, end)) {
      const T* X_ptr = X_data + i * N;
      T* Y_ptr = Y_data + i * N;
      auto [mean_val, rstd_val] = RowwiseMoments(X_ptr, N);
      rstd_val = T(1) / std::sqrt(rstd_val + eps);
      const T scale = rstd_val;
      const T bias = -mean_val;
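      // With a missing gamma or beta, use a scalar loop that substitutes
      // gamma = 1 / beta = 0; otherwise use the fused vectorized map.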
      if (gamma_null || beta_null) {
        for (const auto j : c10::irange(N)) {
          const T gamma_v = gamma_null ? T(1) : gamma_data[j];
          const T beta_v = beta_null ? T(0) : beta_data[j];
          Y_ptr[j] = (X_ptr[j] + bias) * rstd_val * gamma_v + beta_v;
        }
      } else {
        vec::map3<T>(
            [scale, bias](Vec x, Vec gamma, Vec beta) {
              return (x + Vec(bias)) * Vec(scale) * gamma + beta;
            },
            Y_ptr,
            X_ptr,
            gamma_data,
            beta_data,
            N);
      }
      if (!mean_null) {
        mean_data[i] = mean_val;
      }
      if (!rstd_null) {
        rstd_data[i] = rstd_val;
      }
    }
  });
}

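// Forward kernel for reduced floating point inputs (BFloat16/Half). Row
// statistics and the normalization math are computed in float; param_t is the
// dtype of gamma/beta and of the saved mean/rstd (float in the mixed dtype
// case, otherwise the same reduced type as the input).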
template <typename T, typename param_t,
          typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
void layer_norm_kernel_mixed_type(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    float eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  using bVec = Vectorized<T>;
  using fVec = Vectorized<float>;
  const T* X_data = X.const_data_ptr<T>();
  const param_t* gamma_data = gamma.defined() ? gamma.const_data_ptr<param_t>() : nullptr;
  const param_t* beta_data = beta.defined() ? beta.const_data_ptr<param_t>() : nullptr;
  T* Y_data = Y->data_ptr<T>();
  param_t* mean_data = mean ? mean->data_ptr<param_t>() : nullptr;
  param_t* rstd_data = rstd ? rstd->data_ptr<param_t>() : nullptr;

  const bool gamma_null = gamma_data == nullptr;
  const bool beta_null = beta_data == nullptr;
  const bool mean_null = mean_data == nullptr;
  const bool rstd_null = rstd_data == nullptr;
  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
    for (const auto i : c10::irange(start, end)) {
      const T* X_ptr = X_data + i * N;
      T* Y_ptr = Y_data + i * N;
      auto [mean_val, rstd_val] = RowwiseMoments(X_ptr, N);
      rstd_val = float(1) / std::sqrt(rstd_val + eps);
      const float scale = rstd_val;
      const float bias = -rstd_val * mean_val;
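      // Fold (x - mean) * rstd into x * scale + bias, with scale = rstd and
      // bias = -rstd * mean.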
      int64_t d = 0;
      for (; d < N - (N % bVec::size()); d += bVec::size()) {
        bVec x_bvec = bVec::loadu(X_ptr + d);
        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
        auto [gamma_fvec0, gamma_fvec1] = gamma_null ? std::make_tuple(fVec(1), fVec(1)) : load2f(gamma_data + d);
        auto [beta_fvec0, beta_fvec1] = beta_null ? std::make_tuple(fVec(0), fVec(0)) : load2f(beta_data + d);
        fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0;
        fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1;
        bVec y_bvec = convert_from_float<T>(y_fvec0, y_fvec1);
        y_bvec.store(Y_ptr + d);
      }
      for (; d < N; d++) {
        const float gamma_v = gamma_null ? float(1) : float(gamma_data[d]);
        const float beta_v = beta_null ? float(0) : float(beta_data[d]);
        Y_ptr[d] = (float(X_ptr[d]) * scale + bias) * gamma_v + beta_v;
      }
      if (!mean_null) {
        mean_data[i] = mean_val;
      }
      if (!rstd_null) {
        rstd_data[i] = rstd_val;
      }
    }
  });
}

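// Reduced floating point overload: picks float parameters when gamma/beta
// (and the saved mean/rstd) are float tensors alongside a BFloat16/Half
// input, otherwise keeps everything in the input's reduced type.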
template <typename T,
          typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
void LayerNormKernelImplInternal(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    float eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  const bool mixed_type = is_mixed_type(X, gamma, beta);
  if (mixed_type) {
    layer_norm_kernel_mixed_type<T, float>(X, gamma, beta, M, N, eps, Y, mean, rstd);
  } else {
    layer_norm_kernel_mixed_type<T, T>(X, gamma, beta, M, N, eps, Y, mean, rstd);
  }
}

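// Entry point for the forward CPU kernel: checks the flattened shapes and
// dispatches on the input dtype (float, double, BFloat16, Half).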
void LayerNormKernelImpl(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    double eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  TORCH_DCHECK_EQ(X.numel(), M * N);
  DCHECK(!gamma.defined() || gamma.numel() == N);
  DCHECK(!beta.defined() || beta.numel() == N);
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, X.scalar_type(),
      "LayerNormKernelImpl", [&]() {
        LayerNormKernelImplInternal<scalar_t>(
            X, gamma, beta, M, N, eps, Y, mean, rstd);
      });
}

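// Processes one row i of the backward pass: accumulates the row's
// contribution to dgamma and dbeta into the provided per-thread buffers and
// computes dX for the row. scale is expected to be 1 / N.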
template <typename T, typename T2, typename opmath_t>
void layer_norm_backward_frame(
    const T* dY_data,
    const T* X_data,
    const T2* mean_data,
    const T2* rstd_data,
    const T2* gamma_data,
    T* dX_data,
    T* dgamma_buffer_ptr,
    T* dbeta_buffer_ptr,
    const opmath_t scale,
    const bool gamma_null,
    const bool dX_null,
    const bool dgamma_null,
    const bool dbeta_null,
    int64_t N,
    int64_t i) {
  using Vec = vec::Vectorized<opmath_t>;
  const T* dY_ptr = dY_data + i * N;
  const T* X_ptr = X_data + i * N;
  if (!dgamma_null) {
    const opmath_t a = rstd_data[i];
    const opmath_t b = -a * mean_data[i];
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
    // }
    vec::map3<T>(
        [a, b](Vec dgamma, Vec dy, Vec x) {
          return dgamma + dy * (Vec(a) * x + Vec(b));
        },
        dgamma_buffer_ptr,
        dgamma_buffer_ptr,
        dY_ptr,
        X_ptr,
        N);
  }
  if (!dbeta_null) {
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   dbeta_data[j] += dY_ptr[j];
    // }
    vec::map2<T>(
        [](Vec dbeta, Vec dy) { return dbeta + dy; },
        dbeta_buffer_ptr,
        dbeta_buffer_ptr,
        dY_ptr,
        N);
  }
  if (!dX_null) {
    T* dX_ptr = dX_data + i * N;
    opmath_t ds = opmath_t(0);
    opmath_t db = opmath_t(0);
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
    //   ds += dY_ptr[j] * X_ptr[j] * gamma_v;
    //   db += dY_ptr[j] * gamma_v;
    // }
    if (gamma_null) {
      ds = vec::map2_reduce_all<T>(
          [](Vec x, Vec y) { return x * y; },
          [](Vec x, Vec y) { return x + y; },
          dY_ptr,
          X_ptr,
          N);
      db = vec::reduce_all<T>(
          [](Vec& x, Vec& y) { return x + y; }, dY_ptr, N);
    } else {
      ds = vec::map3_reduce_all<T>(
          [](Vec x, Vec y, Vec z) { return x * y * z; },
          [](Vec x, Vec y) { return x + y; },
          dY_ptr,
          X_ptr,
          gamma_data,
          N);
      db = vec::map2_reduce_all<T>(
          [](Vec x, Vec y) { return x * y; },
          [](Vec x, Vec y) { return x + y; },
          dY_ptr,
          gamma_data,
          N);
    }
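    // Fold the row reductions into the dX coefficients:
    //   dX[j] = a * dY[j] * gamma[j] + b * X[j] + c
    // where a = rstd, b = (db * mean - ds) * rstd^3 * scale and
    // c = -b * mean - db * rstd * scale, with scale = 1 / N.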
    const opmath_t a = rstd_data[i];
    const opmath_t b = (db * opmath_t(mean_data[i]) - ds) * a * a * a * scale;
    const opmath_t c = -b * opmath_t(mean_data[i]) - db * a * scale;
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
    //   dX_ptr[j] = a * dY_ptr[j] * gamma_v + b * X_ptr[j] + c;
    // }
    if (gamma_null) {
      vec::map2<T>(
          [a, b, c](Vec dy, Vec x) {
            return Vec(a) * dy + Vec(b) * x + Vec(c);
          },
          dX_ptr,
          dY_ptr,
          X_ptr,
          N);
    } else {
      vec::map3<T>(
          [a, b, c](Vec dy, Vec gamma, Vec x) {
            return Vec(a) * dy * gamma + Vec(b) * x + Vec(c);
          },
          dX_ptr,
          dY_ptr,
          gamma_data,
          X_ptr,
          N);
    }
  }
}

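// Overload of layer_norm_backward_frame for reduced floating point T with
// float statistics and parameters (T2 == float): bVec loads are widened to
// pairs of float vectors and all arithmetic is carried out in float.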
template <typename T, typename T2, typename opmath_t,
          typename std::enable_if_t<is_reduced_floating_point_v<T> && std::is_same<T2, float>::value, int> = 0>
void layer_norm_backward_frame(
    const T* dY_data,
    const T* X_data,
    const float* mean_data,
    const float* rstd_data,
    const float* gamma_data,
    T* dX_data,
    T* dgamma_buffer_ptr,
    T* dbeta_buffer_ptr,
    const float scale,
    const bool gamma_null,
    const bool dX_null,
    const bool dgamma_null,
    const bool dbeta_null,
    int64_t N,
    int64_t i) {
  using bVec = Vectorized<T>;
  using fVec = Vectorized<float>;
  const T* dY_ptr = dY_data + i * N;
  const T* X_ptr = X_data + i * N;
  if (!dgamma_null) {
    const float a = rstd_data[i];
    const float b = -a * mean_data[i];
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
    // }
    vec::map3<T>(
        [a, b](fVec dgamma, fVec dy, fVec x) {
          return dgamma + dy * (fVec(a) * x + fVec(b));
        },
        dgamma_buffer_ptr,
        dgamma_buffer_ptr,
        dY_ptr,
        X_ptr,
        N);
  }
  if (!dbeta_null) {
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   dbeta_data[j] += dY_ptr[j];
    // }
    vec::map2<T>(
        [](fVec dbeta, fVec dy) { return dbeta + dy; },
        dbeta_buffer_ptr,
        dbeta_buffer_ptr,
        dY_ptr,
        N);
  }
  if (!dX_null) {
    T* dX_ptr = dX_data + i * N;
    float ds = float(0);
    float db = float(0);
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
    //   ds += dY_ptr[j] * X_ptr[j] * gamma_v;
    //   db += dY_ptr[j] * gamma_v;
    // }
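    // Vectorized ds/db reduction. With gamma present, the row is processed in
    // bVec-sized chunks that are widened to pairs of float vectors and
    // accumulated in float; rows shorter than one bVec use masked loads.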
    if (gamma_null) {
      ds = vec::map2_reduce_all<T>(
          [](fVec x, fVec y) { return x * y; },
          [](fVec x, fVec y) { return x + y; },
          dY_ptr,
          X_ptr,
          N);
      db = vec::reduce_all<T>(
          [](fVec& x, fVec& y) { return x + y; }, dY_ptr, N);
    } else {
      if (N < bVec::size()) {
        bVec x_bvec = bVec::loadu(X_ptr, N);
        bVec dy_bvec = bVec::loadu(dY_ptr, N);
        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
        auto [dy_fvec0, dy_fvec1] = convert_to_float<T>(dy_bvec);
        auto [gamma_fvec0, gamma_fvec1] = load2f(gamma_data, N);
        if (N > fVec::size()) {
          fVec db_fvec0 = dy_fvec0 * gamma_fvec0;
          fVec db_fvec1 = dy_fvec1 * gamma_fvec1;
          fVec ds_fvec0 = x_fvec0 * db_fvec0;
          fVec ds_fvec1 = x_fvec1 * db_fvec1;
          ds_fvec0 = fVec::set(ds_fvec0, ds_fvec0 + ds_fvec1, N - fVec::size());
          ds = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, ds_fvec0);
          db_fvec0 = fVec::set(db_fvec0, db_fvec0 + db_fvec1, N - fVec::size());
          db = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, db_fvec0);
        } else {
          fVec db_fvec0 = dy_fvec0 * gamma_fvec0;
          fVec ds_fvec0 = x_fvec0 * db_fvec0;
          ds = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, ds_fvec0, N);
          db = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, db_fvec0, N);
        }
      } else {
        int64_t d = bVec::size();
        bVec x_bvec = bVec::loadu(X_ptr);
        bVec dy_bvec = bVec::loadu(dY_ptr);
        fVec ds_fvec0, ds_fvec1, db_fvec0, db_fvec1, acc_ds_fvec0, acc_ds_fvec1, acc_db_fvec0, acc_db_fvec1;
        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
        auto [dy_fvec0, dy_fvec1] = convert_to_float<T>(dy_bvec);
        auto [gamma_fvec0, gamma_fvec1] = load2f(gamma_data);
        acc_db_fvec0 = dy_fvec0 * gamma_fvec0;
        acc_db_fvec1 = dy_fvec1 * gamma_fvec1;
        acc_ds_fvec0 = x_fvec0 * acc_db_fvec0;
        acc_ds_fvec1 = x_fvec1 * acc_db_fvec1;
        for (; d < N - (N % bVec::size()); d += bVec::size()) {
          x_bvec = bVec::loadu(X_ptr + d);
          dy_bvec = bVec::loadu(dY_ptr + d);
          std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
          std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
          std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
          db_fvec0 = dy_fvec0 * gamma_fvec0;
          db_fvec1 = dy_fvec1 * gamma_fvec1;
          ds_fvec0 = x_fvec0 * db_fvec0;
          ds_fvec1 = x_fvec1 * db_fvec1;
          acc_ds_fvec0 = acc_ds_fvec0 + ds_fvec0;
          acc_ds_fvec1 = acc_ds_fvec1 + ds_fvec1;
          acc_db_fvec0 = acc_db_fvec0 + db_fvec0;
          acc_db_fvec1 = acc_db_fvec1 + db_fvec1;
        }
        if (N - d > 0) {
          x_bvec = bVec::loadu(X_ptr + d, N - d);
          dy_bvec = bVec::loadu(dY_ptr + d, N - d);
          std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
          std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
          std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d, N - d);
          if (N - d > fVec::size()) {
            db_fvec0 = dy_fvec0 * gamma_fvec0;
            db_fvec1 = dy_fvec1 * gamma_fvec1;
            ds_fvec0 = x_fvec0 * db_fvec0;
            ds_fvec1 = x_fvec1 * db_fvec1;
            acc_ds_fvec0 = acc_ds_fvec0 + ds_fvec0;
            acc_ds_fvec1 = fVec::set(acc_ds_fvec1, acc_ds_fvec1 + ds_fvec1, N - d - fVec::size());
            acc_db_fvec0 = acc_db_fvec0 + db_fvec0;
            acc_db_fvec1 = fVec::set(acc_db_fvec1, acc_db_fvec1 + db_fvec1, N - d - fVec::size());
          } else {
            db_fvec0 = dy_fvec0 * gamma_fvec0;
            ds_fvec0 = x_fvec0 * db_fvec0;
            acc_ds_fvec0 = fVec::set(acc_ds_fvec0, acc_ds_fvec0 + ds_fvec0, N - d);
            acc_db_fvec0 = fVec::set(acc_db_fvec0, acc_db_fvec0 + db_fvec0, N - d);
          }
        }
        acc_ds_fvec0 = acc_ds_fvec0 + acc_ds_fvec1;
        acc_db_fvec0 = acc_db_fvec0 + acc_db_fvec1;
        ds = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, acc_ds_fvec0);
        db = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, acc_db_fvec0);
      }
    }
    const float a = rstd_data[i];
    const float b = (db * mean_data[i] - ds) * a * a * a * scale;
    const float c = -b * mean_data[i] - db * a * scale;
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
    //   dX_ptr[j] = a * dY_ptr[j] * gamma_v + b * X_ptr[j] + c;
    // }
    if (gamma_null) {
      vec::map2<T>(
          [a, b, c](fVec dy, fVec x) {
            return fVec(a) * dy + fVec(b) * x + fVec(c);
          },
          dX_ptr,
          dY_ptr,
          X_ptr,
          N);
    } else {
      int64_t d = 0;
      for (; d < N - (N % bVec::size()); d += bVec::size()) {
        bVec x_bvec = bVec::loadu(X_ptr + d);
        bVec dy_bvec = bVec::loadu(dY_ptr + d);
        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
        auto [dy_fvec0, dy_fvec1] = convert_to_float<T>(dy_bvec);
        auto [gamma_fvec0, gamma_fvec1] = load2f(gamma_data + d);
        fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
        fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
        bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
        r_bvec.store(dX_ptr + d);
      }
      if (N - d > 0) {
        bVec x_bvec = bVec::loadu(X_ptr + d, N - d);
        bVec dy_bvec = bVec::loadu(dY_ptr + d, N - d);
        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
        auto [dy_fvec0, dy_fvec1] = convert_to_float<T>(dy_bvec);
        auto [gamma_fvec0, gamma_fvec1] = load2f(gamma_data + d, N - d);
        fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
        fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
        bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
        r_bvec.store(dX_ptr + d, N - d);
      }
    }
  }
}

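// Backward kernel body. T is the dtype of dY/X/dX; T2 is the dtype of
// mean/rstd/gamma and of the dgamma/dbeta outputs (float for the mixed dtype
// case, otherwise equal to T).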
template <typename T, typename T2>
void LayerNormBackwardKernelImplInternal(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t M,
    int64_t N,
    Tensor* dX,
    Tensor* dgamma,
    Tensor* dbeta) {
  using opmath_t = at::opmath_type<T>;
  TORCH_DCHECK_EQ(dY.numel(), M * N);
  TORCH_DCHECK_EQ(X.numel(), M * N);
  TORCH_DCHECK_EQ(mean.numel(), M);
  TORCH_DCHECK_EQ(rstd.numel(), M);
  DCHECK(!gamma.defined() || gamma.numel() == N);
  const T* dY_data = dY.template const_data_ptr<T>();
  const T* X_data = X.template const_data_ptr<T>();
  const T2* mean_data = mean.template const_data_ptr<T2>();
  const T2* rstd_data = rstd.template const_data_ptr<T2>();
  const T2* gamma_data =
      gamma.defined() ? gamma.template const_data_ptr<T2>() : nullptr;
  T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
  T2* dgamma_data = dgamma->defined() ? dgamma->template data_ptr<T2>() : nullptr;
  T2* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T2>() : nullptr;
  const opmath_t scale = opmath_t(1) / static_cast<opmath_t>(N);
  const bool gamma_null = gamma_data == nullptr;
  const bool dX_null = dX_data == nullptr;
  const bool dgamma_null = dgamma_data == nullptr;
  const bool dbeta_null = dbeta_data == nullptr;

  // 1. Use a two-pass parallel reduction for dgamma and dbeta:
  //    First pass: allocate an intermediate buffer of size {2, max_threads, N},
  //    with dgamma_buffer = buffer[0] and dbeta_buffer = buffer[1].
  //    Parallelize along dim0 and reduce dY and X along dim0 into the buffer.
  //    Second pass: parallelize along dim1 and reduce the buffer into dgamma
  //    and dbeta.
  //
  // 2. Fuse the first pass of dgamma/dbeta with dX to reuse X[i] and dY[i]
  //    while they are hot in the L1 cache.
  //
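  // Buffer layout (zero-initialized when used):
  //   buffer_data[tid * N + j]                 -> thread tid's partial dgamma[j]
  //   buffer_data[(num_threads + tid) * N + j] -> thread tid's partial dbeta[j]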
  int num_threads = at::get_num_threads();
  Tensor buffer = at::empty({0}, X.options());
  T* buffer_data = nullptr;
  if (!dgamma_null || !dbeta_null) {
    // Zero the intermediate buffer so that zeroing dgamma and dbeta can be skipped.
    buffer.resize_({2, num_threads, N}).zero_();
    buffer_data = buffer.template data_ptr<T>();
  }

  // First pass of dgamma/dbeta, fused with dX
  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(
        tid < num_threads,
        "expect thread id smaller than ",
        num_threads,
        ", got thread id ",
        tid);
    T* dgamma_buffer_ptr = dgamma_null ? nullptr : buffer_data + tid * N;
    T* dbeta_buffer_ptr =
        dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
    for (const auto i : c10::irange(start, end)) {
      layer_norm_backward_frame<T, T2, opmath_t>(
          dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data,
          dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null,
          dgamma_null, dbeta_null, N, i);
    }
  });

  // Second pass of dgamma/dbeta: reduce the per-thread partial sums
  if (buffer_data != nullptr) {
    parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
      for (const auto j : c10::irange(start, end)) {
        opmath_t dgamma_v = opmath_t(0);
        opmath_t dbeta_v = opmath_t(0);
        for (const auto i : c10::irange(num_threads)) {
          dgamma_v += buffer_data[i * N + j];
          dbeta_v += buffer_data[num_threads * N + i * N + j];
        }
        if (!dgamma_null) {
          // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
          dgamma_data[j] = dgamma_v;
        }
        if (!dbeta_null) {
          // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
          dbeta_data[j] = dbeta_v;
        }
      }
    });
  }
}

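// Entry point for the backward CPU kernel: dispatches on the input dtype, and
// for BFloat16/Half inputs additionally on whether gamma is stored as float
// (the mixed dtype case).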
void LayerNormBackwardKernelImpl(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t M,
    int64_t N,
    Tensor* dX,
    Tensor* dgamma,
    Tensor* dbeta) {
  if (at::isReducedFloatingType(X.scalar_type())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
      if (gamma.scalar_type() == at::kFloat) {
        LayerNormBackwardKernelImplInternal<scalar_t, float>(
            dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
      } else {
        LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
            dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
      }
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
      LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
          dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
    });
  }
}

} // namespace

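// Register the CPU implementations so that the LayerNormKernel and
// LayerNormBackwardKernel dispatch stubs resolve to them.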
REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl);
REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl);

} // namespace at::native