#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/accumulate.h>

#include <optional>
#include <sstream>
#include <utility>

namespace at::native {

namespace {

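// Checks that `normalized_shape` is non-empty, that `weight`/`bias` (when
// defined) match it exactly, and that it matches the trailing dimensions of
// `input`. Returns the flattened problem size as a pair (M, N), where M is
// the product of the leading (non-normalized) dimensions and N is the
// product of normalized_shape.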
C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& weight /* optional */,
    const Tensor& bias /* optional */) {

  const int normalized_ndim = normalized_shape.size();
  TORCH_CHECK(
      normalized_ndim >= 1,
      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
      "containing at least one element, but got normalized_shape = ",
      normalized_shape);
  TORCH_CHECK(
      !weight.defined() || weight.sizes().equals(normalized_shape),
      "Expected weight to be of same shape as normalized_shape, but got ",
      "weight of shape ",
      weight.sizes(),
      " and normalized_shape = ",
      normalized_shape);
  TORCH_CHECK(
      !bias.defined() || bias.sizes().equals(normalized_shape),
      "Expected bias to be of same shape as normalized_shape, but got ",
      "bias of shape ",
      bias.sizes(),
      " and normalized_shape = ",
      normalized_shape);

  const auto input_shape = input.sizes();
  const auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim)
           .equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
50 ss << "], but got input of size" << input_shape;
    AT_ERROR(ss.str());
  }

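  // Fold the input into a 2-D (M, N) problem: dimensions before `axis`
  // collapse into M rows and the normalized dimensions collapse into N
  // columns. E.g. input shape [2, 3, 4] with normalized_shape = [4] gives
  // axis = 2, M = 6, N = 4.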
  const int axis = input_ndim - normalized_ndim;
  const int64_t M =
      c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
  const int64_t N =
      c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());

  return std::make_pair(M, N);
}

} // namespace

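// Out-variant of the CPU layer_norm: normalizes the M x N flattened view of
// `input` into `out`, applying the affine transform with `gamma` (scale) and
// `beta` (shift) when they are defined. M and N are the sizes computed by
// _check_layer_norm_inputs.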
void layer_norm_cpu_out(
    at::Tensor& out,
    const at::Tensor& input,
    const Tensor& gamma,
    const Tensor& beta,
    double eps,
    int64_t M,
    int64_t N);

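// RMS normalization (Zhang & Sennrich, 2019). Unlike layer_norm there is no
// mean subtraction and no bias: y = x / sqrt(mean(x^2) + eps) * weight.
// When `eps` is nullopt, an epsilon derived from the input dtype is used.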
Tensor rms_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    std::optional<double> eps);

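// Forward kernel signature: reads the M x N flattened view of X, writes the
// normalized result into *Y, and saves the per-row statistics into *mean and
// *rstd (reciprocal standard deviation) for reuse in the backward pass.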
using forward_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
    const Tensor& /* beta */,
    int64_t /* M */,
    int64_t /* N */,
    double /* eps */,
    Tensor* /* Y */,
    Tensor* /* mean */,
    Tensor* /* rstd */);

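// Backward kernel signature: given the upstream gradient dY and the
// statistics saved by the forward pass, fills the input gradient *dX and the
// parameter gradients *dgamma / *dbeta.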
using backward_fn = void (*)(
    const Tensor& /* dY */,
    const Tensor& /* X */,
    const Tensor& /* mean */,
    const Tensor& /* rstd */,
    const Tensor& /* gamma */,
    int64_t /* M */,
    int64_t /* N */,
    Tensor* /* dX */,
    Tensor* /* dgamma */,
    Tensor* /* dbeta */);

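// Per-device kernel entry points registered through the DispatchStub
// machinery; backends provide implementations via REGISTER_DISPATCH. A
// typical forward call dispatches on the device type (a sketch, assuming Y,
// mean and rstd were pre-allocated with M rows):
//
//   LayerNormKernel(
//       X.device().type(), X, gamma, beta, M, N, eps, &Y, &mean, &rstd);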
DECLARE_DISPATCH(forward_fn, LayerNormKernel);
DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel);

} // namespace at::native