// aten/src/ATen/native/layer_norm.h
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/accumulate.h>

#include <sstream>
#include <utility>

namespace at::native {

namespace {

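// Validates the inputs to layer_norm and flattens the shape bookkeeping:
// `normalized_shape` must match the trailing dimensions of `input`, and
// `weight` / `bias`, when defined, must match `normalized_shape` exactly.
// Returns (M, N), where M is the number of normalized slices (the product
// of the leading dimensions) and N is the size of each slice (the product
// of `normalized_shape`).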
C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& weight /* optional */,
    const Tensor& bias /* optional */) {

  const int normalized_ndim = normalized_shape.size();
  TORCH_CHECK(
      normalized_ndim >= 1,
      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
      "containing at least one element, but got normalized_shape = ",
      normalized_shape);
  TORCH_CHECK(
      !weight.defined() || weight.sizes().equals(normalized_shape),
      "Expected weight to be of same shape as normalized_shape, but got ",
      "weight of shape ",
      weight.sizes(),
      " and normalized_shape = ",
      normalized_shape);
  TORCH_CHECK(
      !bias.defined() || bias.sizes().equals(normalized_shape),
      "Expected bias to be of same shape as normalized_shape, but got ",
      "bias of shape ",
      bias.sizes(),
      " and normalized_shape = ",
      normalized_shape);

  const auto input_shape = input.sizes();
  const auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim)
           .equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size " << input_shape;
    TORCH_CHECK(false, ss.str());
  }

  const int axis = input_ndim - normalized_ndim;
  const int64_t M =
      c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
  const int64_t N =
      c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());

  return std::make_pair(M, N);
}
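// Illustrative example (not part of the API): for an input of shape
// [2, 3, 4, 5] and normalized_shape = [4, 5], axis == 2, so
// M == 2 * 3 == 6 and N == 4 * 5 == 20.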

} // namespace

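// Writes the layer normalization of `input` into `out` on CPU, treating the
// input as M rows of N elements each (as computed by
// _check_layer_norm_inputs) and applying the affine parameters `gamma` and
// `beta` when they are defined.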
void layer_norm_cpu_out(
    at::Tensor& out,
    const at::Tensor& input,
    const Tensor& gamma,
    const Tensor& beta,
    double eps,
    int64_t M,
    int64_t N);

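// RMS normalization over the trailing `normalized_shape` dimensions:
// y = x / sqrt(mean(x^2) + eps), optionally scaled by `weight`. When `eps`
// is nullopt, a dtype-dependent default (the input dtype's epsilon) is used.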
Tensor rms_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    std::optional<double> eps);

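// Forward kernel signature: reads X as an M x N matrix, writes the
// normalized output to Y along with the per-row mean and rstd
// (1 / sqrt(var + eps)) that the backward pass reuses.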
using forward_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
    const Tensor& /* beta */,
    int64_t /* M */,
    int64_t /* N */,
    double /* eps */,
    Tensor* /* Y */,
    Tensor* /* mean */,
    Tensor* /* rstd */);

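// Backward kernel signature: consumes the upstream gradient dY together with
// the saved X, mean, rstd, and gamma, and fills the requested gradients
// dX, dgamma, and dbeta.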
using backward_fn = void (*)(
    const Tensor& /* dY */,
    const Tensor& /* X */,
    const Tensor& /* mean */,
    const Tensor& /* rstd */,
    const Tensor& /* gamma */,
    int64_t /* M */,
    int64_t /* N */,
    Tensor* /* dX */,
    Tensor* /* dgamma */,
    Tensor* /* dbeta */);

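// Dispatch stubs: per-backend kernel implementations are registered against
// these with REGISTER_DISPATCH and selected at runtime by device type.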
DECLARE_DISPATCH(forward_fn, LayerNormKernel);
DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel);
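// Sketch of a caller (assumes X, gamma, beta, M, N, and eps are already set
// up as for layer_norm_cpu_out; buffer dtypes are simplified here):
//   Tensor Y = at::empty_like(X);
//   Tensor mean = at::empty({M}, X.options());
//   Tensor rstd = at::empty({M}, X.options());
//   LayerNormKernel(kCPU, X, gamma, beta, M, N, eps, &Y, &mean, &rstd);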

} // namespace at::native