/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// Slightly modified version of caffe2/aten/src/ATen/native/cpu/moments_utils.h
// for use in optimized ExecuTorch ops. Template specializations of BFloat16
// are excluded.

#include <executorch/kernels/optimized/vec/vec.h>

#include <executorch/kernels/optimized/utils/math_utils.h>
#include <executorch/runtime/platform/compiler.h>
#include <algorithm>
#include <array>
#include <cstdint>
#include <utility>

namespace torch {
namespace executor {
namespace native {

template <typename T>
using acc_t = executorch::utils::compute_dtype<T>;

constexpr int64_t kChunkSize = 16;

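// Merge the Welford accumulator (m0_add, m1_add, m2_add) into (m0, m1, m2),
// following Chan et al.'s parallel update: with n = m0 + m0_add and
// delta = m1_add - m1,
//   m1' = m1 + delta * m0_add / n
//   m2' = m2 + m2_add + delta^2 * m0 * m0_add / n
// Here m0 counts samples, m1 is the running mean, and m2 is the running sum
// of squared deviations from the mean.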
template <typename T>
void AddMoments(
    int64_t m0_add,
    const T& m1_add,
    const T& m2_add,
    int64_t& m0,
    T& m1,
    T& m2) {
  const int64_t n = m0 + m0_add;
  const T c =
      n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
  const T delta = m1_add - m1;
  m1 += c * delta;
  m2 += m2_add + delta * delta * c * static_cast<T>(m0);
  m0 = n;
}

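// Vectorized variant of AddMoments: each SIMD lane carries its own
// independent (mean, m2) pair, while the sample count m0 is shared across
// all lanes.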
template <typename T>
ET_INLINE void AddMomentsVec(
    int64_t m0_add,
    const executorch::vec::Vectorized<T>& m1_add,
    const executorch::vec::Vectorized<T>& m2_add,
    int64_t& m0,
    executorch::vec::Vectorized<T>& m1,
    executorch::vec::Vectorized<T>& m2) {
  using Vec = executorch::vec::Vectorized<T>;
  const int64_t n = m0 + m0_add;
  const T c =
      n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
  const Vec c_vec(c);
  const Vec delta = m1_add - m1;
  m1 += c_vec * delta;
  m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
  m0 = n;
}

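// Run a serial Welford pass over up to kChunkSize vectors starting at X_ptr
// and merge the result into the level-0 accumulator. c_vecs must hold the
// precomputed reciprocals 1/1, 1/2, ..., 1/kChunkSize so that the inner loop
// avoids a division per step.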
template <typename T>
inline void UpdateMomentsVec(
    int64_t m0,
    const T* X_ptr,
    const std::array<executorch::vec::Vectorized<acc_t<T>>, kChunkSize>& c_vecs,
    int64_t& m0_stk0,
    executorch::vec::Vectorized<acc_t<T>>& m1_stk0,
    executorch::vec::Vectorized<acc_t<T>>& m2_stk0) {
  using Vec = executorch::vec::Vectorized<acc_t<T>>;
  Vec m1_vec(0);
  Vec m2_vec(0);
  for (int64_t j = 0; j < m0; ++j) {
    const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size());
    const Vec delta_vec = x_vec - m1_vec;
    m1_vec += delta_vec * c_vecs[j];
    m2_vec += delta_vec * (x_vec - m1_vec);
  }
  AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
}

// Compute rowwise moments with a parallel Welford algorithm, using cascaded
// (pairwise) summation to improve numerical stability.
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
// https://en.wikipedia.org/wiki/Pairwise_summation
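// The input is processed in three phases: (1) chunks of kChunkSize vectors
// are reduced with a vectorized Welford pass, (2) per-chunk results are
// merged pairwise through a small stack of accumulators, and (3) a scalar
// Welford pass over the tail is merged with the per-lane results to produce
// a single (mean, m2) pair.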
template <typename T, int64_t kMaxDepth>
std::pair<acc_t<T>, acc_t<T>>
RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
  using T_ACC = acc_t<T>;

  constexpr int64_t kVecSize = executorch::vec::Vectorized<T>::size();
  constexpr int64_t kAccVecSize = executorch::vec::Vectorized<T_ACC>::size();
  const int64_t n = N / kVecSize;
  const int64_t m = executorch::utils::divup(n, kChunkSize);
  const int64_t depth = executorch::utils::CeilLog2(m);

  using Vec = executorch::vec::Vectorized<T_ACC>;
  const Vec kZeroVec(T_ACC(0));
  std::array<int64_t, kMaxDepth> m0_stk;
  std::array<Vec, kMaxDepth> m1_stk;
  std::array<Vec, kMaxDepth> m2_stk;
  for (int64_t i = 0; i < kMaxDepth; ++i) {
    m0_stk[i] = 0;
    m1_stk[i] = kZeroVec;
    m2_stk[i] = kZeroVec;
  }

  for (int64_t i = 0; i < m; ++i) {
    const T* X_ptr = X + i * kChunkSize * kVecSize;
    const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize);
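    // Reciprocals 1 / (j + 1) used by the Welford mean updates; a
    // function-local static, so they are computed only once.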
    static std::array<Vec, kChunkSize> c_vecs = ([]() {
      std::array<Vec, kChunkSize> result;
      for (int64_t j = 0; j < kChunkSize; ++j) {
        result[j] = Vec(T_ACC(1) / static_cast<T_ACC>(j + 1));
      }
      return result;
    })();
    UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]);

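    // Cascade step: the stack behaves like a binary counter over the chunk
    // count. Level j - 1 is merged into level j (and then reset) whenever
    // (i + 1) is divisible by 2^j, so each element passes through at most
    // O(log m) merges.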
    int64_t mask = i + 1;
    for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) {
      AddMomentsVec(
          m0_stk[j - 1],
          m1_stk[j - 1],
          m2_stk[j - 1],
          m0_stk[j],
          m1_stk[j],
          m2_stk[j]);
      m0_stk[j - 1] = 0;
      m1_stk[j - 1] = kZeroVec;
      m2_stk[j - 1] = kZeroVec;
      mask >>= 1;
    }
  }
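  // Collapse the remaining stack levels into level 0.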
  for (int64_t i = 1; i < depth; ++i) {
    AddMomentsVec(
        m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
  }

  std::array<T_ACC, kAccVecSize> m1_arr{};
  std::array<T_ACC, kAccVecSize> m2_arr{};
  m1_stk[0].store(m1_arr.data());
  m2_stk[0].store(m2_arr.data());

  int64_t m0 = 0;
  T_ACC m1 = 0;
  T_ACC m2 = 0;
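  // Scalar Welford pass over the tail elements that do not fill a full
  // vector.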
  for (int64_t i = n * kVecSize; i < N; ++i) {
    T_ACC x = static_cast<T_ACC>(X[i]);
    const T_ACC delta = x - m1;
    ++m0;
    m1 += delta / static_cast<T_ACC>(m0);
    m2 += delta * (x - m1);
  }
  // Each accumulator lane has seen n * kVecSize / kAccVecSize elements. (For
  // BFloat16, which is excluded here, each lane would hold 2 * n results.)
  int64_t m0_add = n * kVecSize / kAccVecSize;
  for (int64_t i = 0; i < kAccVecSize; ++i) {
    AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2);
  }

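  // m1 now holds the overall mean and m2 the total sum of squared
  // deviations, so the variance is m2 / (N - ddof); ddof = 1 gives Bessel's
  // correction.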
  return std::make_pair(m1, m2 / static_cast<T_ACC>(N - ddof));
}

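// Dispatch on the cascade depth so that RowwiseMomentsImpl can size its
// accumulator stack (kMaxDepth) at compile time.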
template <typename T>
std::pair<acc_t<T>, acc_t<T>>
RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
  using Vec = executorch::vec::Vectorized<T>;
  constexpr int64_t kVecSize = Vec::size();
  const int64_t n = N / kVecSize;
  const int64_t m = executorch::utils::divup(n, kChunkSize);
  const int64_t depth = executorch::utils::CeilLog2(m);
  if (depth <= 4) {
    return RowwiseMomentsImpl<T, 4>(X, N, ddof);
  } else if (depth <= 8) {
    return RowwiseMomentsImpl<T, 8>(X, N, ddof);
  } else if (depth <= 16) {
    return RowwiseMomentsImpl<T, 16>(X, N, ddof);
  } else if (depth <= 32) {
    return RowwiseMomentsImpl<T, 32>(X, N, ddof);
  } else {
    return RowwiseMomentsImpl<T, 64>(X, N, ddof);
  }
}
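
// Illustrative usage (a sketch, not part of this header's API; assumes a
// contiguous row of floats):
//
//   const float* row = ...;  // N contiguous values
//   auto moments = RowwiseMoments(row, N);             // population variance
//   float mean = moments.first;
//   float var = moments.second;
//   auto sample = RowwiseMoments(row, N, /*ddof=*/1);  // Bessel-corrected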

} // namespace native
} // namespace executor
} // namespace torch