// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/layer_norm_kernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/layer_norm.h>

#include <cmath>
#include <tuple>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/moments_utils.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace at::native {

namespace {

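// Forward kernel for full-precision inputs (float/double). For each of the M
// rows of length N it computes the row mean and variance via RowwiseMoments,
// converts the variance to rstd = 1 / sqrt(var + eps), and writes
// Y = (X - mean) * rstd * gamma + beta. The vectorized vec::map3 path is used
// only when both gamma and beta are present; otherwise a scalar loop
// substitutes 1 for a missing gamma and 0 for a missing beta. mean and rstd
// are optionally stored per row for use by the backward pass.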
template <typename T,
          typename std::enable_if_t<!is_reduced_floating_point_v<T>, int> = 0>
void LayerNormKernelImplInternal(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    T eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  using Vec = vec::Vectorized<T>;
  const T* X_data = X.const_data_ptr<T>();
  const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
  const T* beta_data = beta.defined() ? beta.const_data_ptr<T>() : nullptr;
  T* Y_data = Y->data_ptr<T>();
  T* mean_data = mean ? mean->data_ptr<T>() : nullptr;
  T* rstd_data = rstd ? rstd->data_ptr<T>() : nullptr;

  const bool gamma_null = gamma_data == nullptr;
  const bool beta_null = beta_data == nullptr;
  const bool mean_null = mean_data == nullptr;
  const bool rstd_null = rstd_data == nullptr;
  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
    for (const auto i : c10::irange(start, end)) {
      const T* X_ptr = X_data + i * N;
      T* Y_ptr = Y_data + i * N;
      auto [mean_val, rstd_val] = RowwiseMoments(X_ptr, N);
      rstd_val = T(1) / std::sqrt(rstd_val + eps);
      const T scale = rstd_val;
      const T bias = -mean_val;
      if (gamma_null || beta_null) {
        for (const auto j : c10::irange(N)) {
          const T gamma_v = gamma_null ? T(1) : gamma_data[j];
          const T beta_v = beta_null ? T(0) : beta_data[j];
          Y_ptr[j] = (X_ptr[j] + bias) * rstd_val * gamma_v + beta_v;
        }
      } else {
        vec::map3<T>(
            [scale, bias](Vec x, Vec gamma, Vec beta) {
              return (x + Vec(bias)) * Vec(scale) * gamma + beta;
            },
            Y_ptr,
            X_ptr,
            gamma_data,
            beta_data,
            N);
      }
      if (!mean_null) {
        mean_data[i] = mean_val;
      }
      if (!rstd_null) {
        rstd_data[i] = rstd_val;
      }
    }
  });
}

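// Forward kernel for reduced floating point inputs (BFloat16/Half). Each bVec
// of input is converted to two float vectors for the arithmetic and converted
// back before storing. param_t is float when gamma/beta (and the mean/rstd
// outputs) are kept in float (the mixed-type case), and T when the parameters
// share the input dtype. A scalar tail loop handles the remaining
// N % bVec::size() elements.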
template <typename T, typename param_t,
          typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
void layer_norm_kernel_mixed_type(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    float eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  using bVec = Vectorized<T>;
  using fVec = Vectorized<float>;
  const T* X_data = X.const_data_ptr<T>();
  const param_t* gamma_data = gamma.defined() ? gamma.const_data_ptr<param_t>() : nullptr;
  const param_t* beta_data = beta.defined() ? beta.const_data_ptr<param_t>() : nullptr;
  T* Y_data = Y->data_ptr<T>();
  param_t* mean_data = mean ? mean->data_ptr<param_t>() : nullptr;
  param_t* rstd_data = rstd ? rstd->data_ptr<param_t>() : nullptr;

  const bool gamma_null = gamma_data == nullptr;
  const bool beta_null = beta_data == nullptr;
  const bool mean_null = mean_data == nullptr;
  const bool rstd_null = rstd_data == nullptr;
  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
    for (const auto i : c10::irange(start, end)) {
      const T* X_ptr = X_data + i * N;
      T* Y_ptr = Y_data + i * N;
      auto [mean_val, rstd_val] = RowwiseMoments(X_ptr, N);
      rstd_val = float(1) / std::sqrt(rstd_val + eps);
      const float scale = rstd_val;
      const float bias = -rstd_val * mean_val;
      int64_t d = 0;
      for (; d < N - (N % bVec::size()); d += bVec::size()) {
        bVec x_bvec = bVec::loadu(X_ptr + d);
        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
        auto [gamma_fvec0, gamma_fvec1] = gamma_null ? std::make_tuple(fVec(1), fVec(1)) : load2f(gamma_data + d);
        auto [beta_fvec0, beta_fvec1] = beta_null ? std::make_tuple(fVec(0), fVec(0)) : load2f(beta_data + d);
        fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0;
        fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1;
        bVec y_bvec = convert_from_float<T>(y_fvec0, y_fvec1);
        y_bvec.store(Y_ptr + d);
      }
      for (; d < N; d++) {
        const float gamma_v = gamma_null ? float(1) : float(gamma_data[d]);
        const float beta_v = beta_null ? float(0) : float(beta_data[d]);
        Y_ptr[d] = (float(X_ptr[d]) * scale + bias) * gamma_v + beta_v;
      }
      if (!mean_null) {
        mean_data[i] = mean_val;
      }
      if (!rstd_null) {
        rstd_data[i] = rstd_val;
      }
    }
  });
}

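// Reduced-precision entry point: when is_mixed_type reports that gamma/beta
// are stored in float, instantiate the kernel with float parameters;
// otherwise the parameters share the input dtype T.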
template <typename T,
          typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
void LayerNormKernelImplInternal(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    float eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  const bool mixed_type = is_mixed_type(X, gamma, beta);
  if (mixed_type) {
    layer_norm_kernel_mixed_type<T, float>(X, gamma, beta, M, N, eps, Y, mean, rstd);
  } else {
    layer_norm_kernel_mixed_type<T, T>(X, gamma, beta, M, N, eps, Y, mean, rstd);
  }
}

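// Public forward entry point (registered below): checks shapes, then
// dispatches on the input dtype (float, double, BFloat16, Half) to the
// internal kernel.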
void LayerNormKernelImpl(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    double eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  TORCH_DCHECK_EQ(X.numel(), M * N);
  DCHECK(!gamma.defined() || gamma.numel() == N);
  DCHECK(!beta.defined() || beta.numel() == N);
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, X.scalar_type(),
      "LayerNormKernelImpl", [&]() {
    LayerNormKernelImplInternal<scalar_t>(
        X, gamma, beta, M, N, eps, Y, mean, rstd);
  });
}

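// Backward for a single row i. With a = rstd[i] and scale = 1 / N, the code
// below accumulates into the per-thread buffers
//   dgamma[j] += dY[j] * (X[j] - mean[i]) * rstd[i]
//   dbeta[j]  += dY[j]
// and, for dX, first reduces
//   ds = sum_j dY[j] * X[j] * gamma[j]
//   db = sum_j dY[j] * gamma[j]
// then writes dX[j] = a * dY[j] * gamma[j] + b * X[j] + c with
//   b = (db * mean[i] - ds) * a^3 * scale
//   c = -b * mean[i] - db * a * scale,
// which is the standard LayerNorm input gradient.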
template <typename T, typename T2, typename opmath_t>
void layer_norm_backward_frame(
    const T* dY_data,
    const T* X_data,
    const T2* mean_data,
    const T2* rstd_data,
    const T2* gamma_data,
    T* dX_data,
    T* dgamma_buffer_ptr,
    T* dbeta_buffer_ptr,
    const opmath_t scale,
    const bool gamma_null,
    const bool dX_null,
    const bool dgamma_null,
    const bool dbeta_null,
    int64_t N,
    int64_t i) {
  using Vec = vec::Vectorized<opmath_t>;
  const T* dY_ptr = dY_data + i * N;
  const T* X_ptr = X_data + i * N;
  if (!dgamma_null) {
    const opmath_t a = rstd_data[i];
    const opmath_t b = -a * mean_data[i];
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
    // }
    vec::map3<T>(
        [a, b](Vec dgamma, Vec dy, Vec x) {
          return dgamma + dy * (Vec(a) * x + Vec(b));
        },
        dgamma_buffer_ptr,
        dgamma_buffer_ptr,
        dY_ptr,
        X_ptr,
        N);
  }
  if (!dbeta_null) {
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   dbeta_data[j] += dY_ptr[j];
    // }
    vec::map2<T>(
        [](Vec dbeta, Vec dy) { return dbeta + dy; },
        dbeta_buffer_ptr,
        dbeta_buffer_ptr,
        dY_ptr,
        N);
  }
  if (!dX_null) {
    T* dX_ptr = dX_data + i * N;
    opmath_t ds = opmath_t(0);
    opmath_t db = opmath_t(0);
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
    //   ds += dY_ptr[j] * X_ptr[j] * gamma_v;
    //   db += dY_ptr[j] * gamma_v;
    // }
    if (gamma_null) {
      ds = vec::map2_reduce_all<T>(
          [](Vec x, Vec y) { return x * y; },
          [](Vec x, Vec y) { return x + y; },
          dY_ptr,
          X_ptr,
          N);
      db = vec::reduce_all<T>(
          [](Vec& x, Vec& y) { return x + y; }, dY_ptr, N);
    } else {
      ds = vec::map3_reduce_all<T>(
          [](Vec x, Vec y, Vec z) { return x * y * z; },
          [](Vec x, Vec y) { return x + y; },
          dY_ptr,
          X_ptr,
          gamma_data,
          N);
      db = vec::map2_reduce_all<T>(
          [](Vec x, Vec y) { return x * y; },
          [](Vec x, Vec y) { return x + y; },
          dY_ptr,
          gamma_data,
          N);
    }
    const opmath_t a = rstd_data[i];
    const opmath_t b = (db * opmath_t(mean_data[i]) - ds) * a * a * a * scale;
    const opmath_t c = -b * opmath_t(mean_data[i]) - db * a * scale;
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
    //   dX_ptr[j] = a * dY_ptr[j] * gamma_v + b * X_ptr[j] + c;
    // }
    if (gamma_null) {
      vec::map2<T>(
          [a, b, c](Vec dy, Vec x) {
            return Vec(a) * dy + Vec(b) * x + Vec(c);
          },
          dX_ptr,
          dY_ptr,
          X_ptr,
          N);
    } else {
      vec::map3<T>(
          [a, b, c](Vec dy, Vec gamma, Vec x) {
            return Vec(a) * dy * gamma + Vec(b) * x + Vec(c);
          },
          dX_ptr,
          dY_ptr,
          gamma_data,
          X_ptr,
          N);
    }
  }
}

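// Reduced floating point specialization of layer_norm_backward_frame
// (T is BFloat16/Half; the statistics and gamma are float). The per-row
// reductions and the dX update convert each bVec into a pair of float
// vectors; the explicit branches handle rows shorter than one bVec and the
// non-multiple tail, using fVec::set to fold only the valid lanes into the
// accumulators.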
template <typename T, typename T2, typename opmath_t,
          typename std::enable_if_t<is_reduced_floating_point_v<T> && std::is_same<T2, float>::value, int> = 0>
void layer_norm_backward_frame(
    const T* dY_data,
    const T* X_data,
    const float* mean_data,
    const float* rstd_data,
    const float* gamma_data,
    T* dX_data,
    T* dgamma_buffer_ptr,
    T* dbeta_buffer_ptr,
    const float scale,
    const bool gamma_null,
    const bool dX_null,
    const bool dgamma_null,
    const bool dbeta_null,
    int64_t N,
    int64_t i) {
  using bVec = Vectorized<T>;
  using fVec = Vectorized<float>;
  const T* dY_ptr = dY_data + i * N;
  const T* X_ptr = X_data + i * N;
  if (!dgamma_null) {
    const float a = rstd_data[i];
    const float b = -a * mean_data[i];
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
    // }
    vec::map3<T>(
        [a, b](fVec dgamma, fVec dy, fVec x) {
          return dgamma + dy * (fVec(a) * x + fVec(b));
        },
        dgamma_buffer_ptr,
        dgamma_buffer_ptr,
        dY_ptr,
        X_ptr,
        N);
  }
  if (!dbeta_null) {
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   dbeta_data[j] += dY_ptr[j];
    // }
    vec::map2<T>(
        [](fVec dbeta, fVec dy) { return dbeta + dy; },
        dbeta_buffer_ptr,
        dbeta_buffer_ptr,
        dY_ptr,
        N);
  }
  if (!dX_null) {
    T* dX_ptr = dX_data + i * N;
    float ds = float(0);
    float db = float(0);
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
    //   ds += dY_ptr[j] * X_ptr[j] * gamma_v;
    //   db += dY_ptr[j] * gamma_v;
    // }
    if (gamma_null) {
      ds = vec::map2_reduce_all<T>(
          [](fVec x, fVec y) { return x * y; },
          [](fVec x, fVec y) { return x + y; },
          dY_ptr,
          X_ptr,
          N);
      db = vec::reduce_all<T>(
          [](fVec& x, fVec& y) { return x + y; }, dY_ptr, N);
    } else {
      if (N < bVec::size()) {
        bVec x_bvec = bVec::loadu(X_ptr, N);
        bVec dy_bvec = bVec::loadu(dY_ptr, N);
        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
        auto [dy_fvec0, dy_fvec1] = convert_to_float<T>(dy_bvec);
        auto [gamma_fvec0, gamma_fvec1] = load2f(gamma_data, N);
        if (N > fVec::size()) {
          fVec db_fvec0 = dy_fvec0 * gamma_fvec0;
          fVec db_fvec1 = dy_fvec1 * gamma_fvec1;
          fVec ds_fvec0 = x_fvec0 * db_fvec0;
          fVec ds_fvec1 = x_fvec1 * db_fvec1;
          ds_fvec0 = fVec::set(ds_fvec0, ds_fvec0 + ds_fvec1, N - fVec::size());
          ds = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, ds_fvec0);
          db_fvec0 = fVec::set(db_fvec0, db_fvec0 + db_fvec1, N - fVec::size());
          db = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, db_fvec0);
        } else {
          fVec db_fvec0 = dy_fvec0 * gamma_fvec0;
          fVec ds_fvec0 = x_fvec0 * db_fvec0;
          ds = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, ds_fvec0, N);
          db = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, db_fvec0, N);
        }
      } else {
        int64_t d = bVec::size();
        bVec x_bvec = bVec::loadu(X_ptr);
        bVec dy_bvec = bVec::loadu(dY_ptr);
        fVec ds_fvec0, ds_fvec1, db_fvec0, db_fvec1, acc_ds_fvec0, acc_ds_fvec1, acc_db_fvec0, acc_db_fvec1;
        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
        auto [dy_fvec0, dy_fvec1] = convert_to_float<T>(dy_bvec);
        auto [gamma_fvec0, gamma_fvec1] = load2f(gamma_data);
        acc_db_fvec0 = dy_fvec0 * gamma_fvec0;
        acc_db_fvec1 = dy_fvec1 * gamma_fvec1;
        acc_ds_fvec0 = x_fvec0 * acc_db_fvec0;
        acc_ds_fvec1 = x_fvec1 * acc_db_fvec1;
        for (; d < N - (N % bVec::size()); d += bVec::size()) {
          x_bvec = bVec::loadu(X_ptr + d);
          dy_bvec = bVec::loadu(dY_ptr + d);
          std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
          std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
          std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
          db_fvec0 = dy_fvec0 * gamma_fvec0;
          db_fvec1 = dy_fvec1 * gamma_fvec1;
          ds_fvec0 = x_fvec0 * db_fvec0;
          ds_fvec1 = x_fvec1 * db_fvec1;
          acc_ds_fvec0 = acc_ds_fvec0 + ds_fvec0;
          acc_ds_fvec1 = acc_ds_fvec1 + ds_fvec1;
          acc_db_fvec0 = acc_db_fvec0 + db_fvec0;
          acc_db_fvec1 = acc_db_fvec1 + db_fvec1;
        }
        if (N - d > 0) {
          x_bvec = bVec::loadu(X_ptr + d, N - d);
          dy_bvec = bVec::loadu(dY_ptr + d, N - d);
          std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
          std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
          std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d, N - d);
          if (N - d > fVec::size()) {
            db_fvec0 = dy_fvec0 * gamma_fvec0;
            db_fvec1 = dy_fvec1 * gamma_fvec1;
            ds_fvec0 = x_fvec0 * db_fvec0;
            ds_fvec1 = x_fvec1 * db_fvec1;
            acc_ds_fvec0 = acc_ds_fvec0 + ds_fvec0;
            acc_ds_fvec1 = fVec::set(acc_ds_fvec1, acc_ds_fvec1 + ds_fvec1, N - d - fVec::size());
            acc_db_fvec0 = acc_db_fvec0 + db_fvec0;
            acc_db_fvec1 = fVec::set(acc_db_fvec1, acc_db_fvec1 + db_fvec1, N - d - fVec::size());
          } else {
            db_fvec0 = dy_fvec0 * gamma_fvec0;
            ds_fvec0 = x_fvec0 * db_fvec0;
            acc_ds_fvec0 = fVec::set(acc_ds_fvec0, acc_ds_fvec0 + ds_fvec0, N - d);
            acc_db_fvec0 = fVec::set(acc_db_fvec0, acc_db_fvec0 + db_fvec0, N - d);
          }
        }
        acc_ds_fvec0 = acc_ds_fvec0 + acc_ds_fvec1;
        acc_db_fvec0 = acc_db_fvec0 + acc_db_fvec1;
        ds = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, acc_ds_fvec0);
        db = vec_reduce_all<float>([](fVec x, fVec y) { return x + y; }, acc_db_fvec0);
      }
    }
    const float a = rstd_data[i];
    const float b = (db * mean_data[i] - ds) * a * a * a * scale;
    const float c = -b * mean_data[i] - db * a * scale;
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
    //   dX_ptr[j] = a * dY_ptr[j] * gamma_v + b * X_ptr[j] + c;
    // }
    if (gamma_null) {
      vec::map2<T>(
          [a, b, c](fVec dy, fVec x) {
            return fVec(a) * dy + fVec(b) * x + fVec(c);
          },
          dX_ptr,
          dY_ptr,
          X_ptr,
          N);
    } else {
      int64_t d = 0;
      for (; d < N - (N % bVec::size()); d += bVec::size()) {
        bVec x_bvec = bVec::loadu(X_ptr + d);
        bVec dy_bvec = bVec::loadu(dY_ptr + d);
        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
        auto [dy_fvec0, dy_fvec1] = convert_to_float<T>(dy_bvec);
        auto [gamma_fvec0, gamma_fvec1] = load2f(gamma_data + d);
        fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
        fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
        bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
        r_bvec.store(dX_ptr + d);
      }
      if (N - d > 0) {
        bVec x_bvec = bVec::loadu(X_ptr + d, N - d);
        bVec dy_bvec = bVec::loadu(dY_ptr + d, N - d);
        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
        auto [dy_fvec0, dy_fvec1] = convert_to_float<T>(dy_bvec);
        auto [gamma_fvec0, gamma_fvec1] = load2f(gamma_data + d, N - d);
        fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
        fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
        bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
        r_bvec.store(dX_ptr + d, N - d);
      }
    }
  }
}

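// Row-parallel backward driver. T is the dtype of dY/X/dX; T2 is the dtype of
// mean/rstd/gamma and of the dgamma/dbeta outputs (float in the mixed
// BFloat16/Half case, otherwise equal to T).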
template <typename T, typename T2>
void LayerNormBackwardKernelImplInternal(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t M,
    int64_t N,
    Tensor* dX,
    Tensor* dgamma,
    Tensor* dbeta) {
  using opmath_t = at::opmath_type<T>;
  TORCH_DCHECK_EQ(dY.numel(), M * N);
  TORCH_DCHECK_EQ(X.numel(), M * N);
  TORCH_DCHECK_EQ(mean.numel(), M);
  TORCH_DCHECK_EQ(rstd.numel(), M);
  DCHECK(!gamma.defined() || gamma.numel() == N);
  const T* dY_data = dY.template const_data_ptr<T>();
  const T* X_data = X.template const_data_ptr<T>();
  const T2* mean_data = mean.template const_data_ptr<T2>();
  const T2* rstd_data = rstd.template const_data_ptr<T2>();
  const T2* gamma_data =
      gamma.defined() ? gamma.template const_data_ptr<T2>() : nullptr;
  T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
  T2* dgamma_data = dgamma->defined() ? dgamma->template data_ptr<T2>() : nullptr;
  T2* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T2>() : nullptr;
  const opmath_t scale = opmath_t(1) / static_cast<opmath_t>(N);
  const bool gamma_null = gamma_data == nullptr;
  const bool dX_null = dX_data == nullptr;
  const bool dgamma_null = dgamma_data == nullptr;
  const bool dbeta_null = dbeta_data == nullptr;

  // 1. Use a two-pass parallel reduction for dgamma and dbeta:
  //    First pass: allocate an intermediate buffer of size {2, max_threads, N},
  //        dgamma_buffer = buffer[0], dbeta_buffer = buffer[1].
  //    Parallelize along dim0 and reduce dY and X along dim0 into the buffer.
  //    Second pass: parallelize along dim1 and reduce the buffer into dgamma
  //    and dbeta.
  //
  // 2. Fuse the first pass of dgamma/dbeta with dX to reuse X[i] and dY[i]
  //    while they are hot in L1 cache.
  //
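  // Buffer layout: buffer is {2, num_threads, N}; thread tid accumulates its
  // dgamma partials at buffer_data + tid * N and its dbeta partials at
  // buffer_data + num_threads * N + tid * N.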
  int num_threads = at::get_num_threads();
  Tensor buffer = at::empty({0}, X.options());
  T* buffer_data = nullptr;
  if (!dgamma_null || !dbeta_null) {
    // zero the intermediate buffer; dgamma and dbeta need no separate zeroing
    buffer.resize_({2, num_threads, N}).zero_();
    buffer_data = buffer.template data_ptr<T>();
  }

  // First pass of dgamma/dbeta, fused with dX
  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(
        tid < num_threads,
        "expect thread id smaller than ",
        num_threads,
        ", got thread id ",
        tid);
    T* dgamma_buffer_ptr = dgamma_null ? nullptr : buffer_data + tid * N;
    T* dbeta_buffer_ptr =
        dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
    for (const auto i : c10::irange(start, end)) {
      layer_norm_backward_frame<T, T2, opmath_t>(
          dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data,
          dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null,
          dgamma_null, dbeta_null, N, i);
    }
  });

  // Second pass of dgamma/dbeta
  if (buffer_data != nullptr) {
    parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
      for (const auto j : c10::irange(start, end)) {
        opmath_t dgamma_v = opmath_t(0);
        opmath_t dbeta_v = opmath_t(0);
        for (const auto i : c10::irange(num_threads)) {
          dgamma_v += buffer_data[i * N + j];
          dbeta_v += buffer_data[num_threads * N + i * N + j];
        }
        if (!dgamma_null) {
          // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
          dgamma_data[j] = dgamma_v;
        }
        if (!dbeta_null) {
          // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
          dbeta_data[j] = dbeta_v;
        }
      }
    });
  }
}

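// Public backward entry point (registered below). For BFloat16/Half inputs the
// statistics and gamma may be kept in float (mixed type), which selects the
// <scalar_t, float> instantiation; otherwise the parameters share the input
// dtype. dY is made contiguous before the row-wise kernels index it as
// dY_data + i * N.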
void LayerNormBackwardKernelImpl(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t M,
    int64_t N,
    Tensor* dX,
    Tensor* dgamma,
    Tensor* dbeta) {
  if (at::isReducedFloatingType(X.scalar_type())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
      if (gamma.scalar_type() == at::kFloat) {
        LayerNormBackwardKernelImplInternal<scalar_t, float>(
            dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
      } else {
        LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
            dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
      }
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
      LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
          dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
    });
  }
}

} // namespace

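// Register the CPU kernels with the dispatch stubs declared in the included
// ATen/native/layer_norm.h header.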
REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl);
REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl);

} // namespace at::native