#pragma once

#include <ATen/Parallel.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/llvmMathExtras.h>

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#endif

namespace at::native {

template <typename T>
inline void _store(T* dst, at::vec::Vectorized<T> src) {
  src.store(dst);
}

inline void _store(at::BFloat16* dst, at::vec::Vectorized<float> src) {
  auto res = at::vec::convert_float_bfloat16(src, src);
  res.store(dst, at::vec::Vectorized<float>::size());
}

inline void _store(at::Half* dst, at::vec::Vectorized<float> src) {
  auto res = at::vec::convert_float_half(src, src);
  res.store(dst, at::vec::Vectorized<float>::size());
}
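
// Illustrative sketch (not part of this header's API; buffer names are
// hypothetical): _store lets a kernel accumulate in float and write the
// result back to a possibly lower-precision buffer.
//
//   float in[at::vec::Vectorized<float>::size()] = {/* ... */};
//   auto v = at::vec::Vectorized<float>::loadu(in);
//   at::BFloat16 bf_out[at::vec::Vectorized<float>::size()];
//   _store(bf_out, v);   // converts float lanes to bfloat16 before storing
//   float f_out[at::vec::Vectorized<float>::size()];
//   _store(f_out, v);    // generic overload: stores float lanes as-is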

inline namespace CPU_CAPABILITY {

template <typename T>
inline T data_index_init(T offset) {
  return offset;
}

template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  offset = data_index_init(offset, std::forward<Args>(args)...);
  x = offset % X;
  return offset / X;
}

inline bool data_index_step() {
  return true;
}

template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  if (data_index_step(std::forward<Args>(args)...)) {
    x = ((x + 1) == X) ? 0 : (x + 1);
    return x == 0;
  }
  return false;
}
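
// Illustrative usage sketch (the names C, H, W, begin, end are hypothetical):
// data_index_init/data_index_step decompose a flat parallel range into
// multi-dimensional indices without a div/mod on every iteration.
//
//   at::parallel_for(0, C * H * W, 0, [&](int64_t begin, int64_t end) {
//     int64_t c = 0, h = 0, w = 0;
//     data_index_init(begin, c, C, h, H, w, W);
//     for (int64_t i = begin; i < end; i++) {
//       // ... work on element (c, h, w) ...
//       data_index_step(c, C, h, H, w, W);
//     }
//   });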

// Helper struct for bfloat16/float16 vectorization.
// Useful when you need float as the intermediate dtype or the accumulation dtype.
using namespace vec;
struct Vec2 {
  Vectorized<float> val0, val1;
  Vec2(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
  Vec2(float v) : val0(v), val1(v) {}
  static Vec2 loadu(const BFloat16* ptr) {
    auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
    return {v0, v1};
  }
  static Vec2 loadu(const Half* ptr) {
    auto [v0, v1] = convert_half_float(Vectorized<Half>::loadu(ptr));
    return {v0, v1};
  }
  static Vec2 loadu(const float* ptr) {
    return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
  }
  void store(BFloat16* ptr) const {
    Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
    val.store(ptr);
  }
  void store(Half* ptr) const {
    Vectorized<Half> val = convert_float_half(val0, val1);
    val.store(ptr);
  }
  void store(float* ptr) const {
    val0.store(ptr);
    val1.store(ptr + Vectorized<float>::size());
  }
};
inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }
inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; }
inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; }
inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; }
inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; }

template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
template <> struct VectorizedType<BFloat16> { using type = Vec2; };
template <> struct VectorizedType<Half> { using type = Vec2; };
template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
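
// Illustrative sketch (hypothetical buffer names): Vec2 covers one full
// Vectorized<BFloat16>/Vectorized<Half> register width while keeping the
// arithmetic in float, so reduced-precision and float code paths can share
// one implementation through VecType<scalar_t>.
//
//   const BFloat16* a_ptr = /* ... */;
//   const BFloat16* b_ptr = /* ... */;
//   BFloat16* out_ptr = /* ... */;
//   Vec2 a = Vec2::loadu(a_ptr);
//   Vec2 b = Vec2::loadu(b_ptr);
//   Vec2 c = maximum(a * b + Vec2(1.0f), Vec2(0.0f));  // math done in float
//   c.store(out_ptr);                                   // converted back to bf16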

// Helpers for loading mixed data type parameters as a pair of Vectorized<float>.
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr) {
  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr) {
  return convert_half_float(Vectorized<Half>::loadu(ptr));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr) {
  using Vec = Vectorized<float>;
  return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr, int64_t count) {
  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr, count));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr, int64_t count) {
  return convert_half_float(Vectorized<Half>::loadu(ptr, count));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr, int64_t count) {
  using Vec = Vectorized<float>;
  if (count > Vec::size()) {
    return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size()));
  } else {
    return std::make_tuple(Vec::loadu(ptr, count), Vec(0));
  }
}
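
// Illustrative sketch (hypothetical pointer names): load2f lets a kernel that
// computes in float consume bfloat16, half, or float parameters uniformly,
// including a partial load of the tail via the `count` overloads.
//
//   const Half* weight_ptr = /* ... */;
//   auto [w0, w1] = load2f(weight_ptr);            // two full Vectorized<float>
//   auto [t0, t1] = load2f(weight_ptr, tail_len);  // loads only tail_len elements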

} // namespace CPU_CAPABILITY

namespace utils {

template <typename T>
T CeilLog2(const T& x) {
  if (x <= 2) {
    return 1;
  }
  // The index of the last set bit is floor(log2(x)); floor + 1 gives the ceil,
  // except when x is an exact power of 2, so subtract 1 first.
  return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
}
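
// Worked examples (illustrative): CeilLog2(8) == 3, since 8 is an exact power
// of two (findLastSet(7) == 2, plus 1), while CeilLog2(9) == 4
// (findLastSet(8) == 3, plus 1). Inputs <= 2 are clamped to 1.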

// matrix transpose:
//   src has shape of M by N, with leading dimension of ld_src
//   dst has shape of N by M, with leading dimension of ld_dst
template <typename T>
inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  for (int64_t j = 0; j < N; j++) {
    for (int64_t i = 0; i < M; i++) {
      dst[j * ld_dst + i] = src[i * ld_src + j];
    }
  }
}

#ifdef USE_FBGEMM
template <>
inline void transpose<float>(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  fbgemm::transpose_simd<float>(M, N, src, ld_src, dst, ld_dst);
}

template <>
inline void transpose<uint16_t>(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) {
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  fbgemm::transpose_simd<uint16_t>(M, N, src, ld_src, dst, ld_dst);
}
#endif
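
// Illustrative sketch (hypothetical shapes and buffers): transposing a
// contiguous 2x3 row-major matrix into a 3x2 one, i.e. ld_src == N and
// ld_dst == M.
//
//   float src[2 * 3] = {1, 2, 3,
//                       4, 5, 6};
//   float dst[3 * 2];
//   utils::transpose(/*M=*/2, /*N=*/3, src, /*ld_src=*/3, dst, /*ld_dst=*/2);
//   // dst is now {1, 4, 2, 5, 3, 6}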

template <typename index_t, typename F>
inline void parallel_sparse_csr(
    const TensorAccessor<index_t, 1>& crow_acc,
    const int64_t M,
    const int64_t nnz,
    const F& f) {
  TORCH_CHECK(crow_acc.size(0) == M + 1);

  // Parallelizing directly over `M` rows may lead to load imbalance,
  // so statically determine the thread partition here so that each
  // thread gets roughly the same number of nonzeros.
  int num_threads = at::get_num_threads();
  std::vector<int64_t> thread_splits(num_threads + 1, M);

  int64_t thread_average_payload = std::max((int64_t)1, divup(nnz, num_threads));

  thread_splits[0] = 0;
  int64_t sum = 0;
  int64_t t = 1;
  for (const auto m : c10::irange(M)) {
    int64_t row_start = crow_acc[m];
    int64_t row_end = crow_acc[m + 1];
    sum += row_end - row_start;
    if (sum > t * thread_average_payload) {
      thread_splits[t] = m;
      t++;
    }
  }
  // Restore the last index: rounding error in `thread_average_payload`
  // may leave it unset by the loop above.
  thread_splits[num_threads] = M;

  at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
    int tid = at::get_thread_num();
    int64_t begin = thread_splits[tid];
    int64_t end = thread_splits[tid + 1];
    f(begin, end);
  });
}
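
// Illustrative usage sketch (hypothetical tensor names): partition CSR rows so
// that each thread handles roughly the same number of nonzeros, then process
// the assigned row range [begin, end) in the callback.
//
//   auto crow = crow_indices.accessor<int64_t, 1>();  // size M + 1
//   utils::parallel_sparse_csr(crow, M, nnz, [&](int64_t begin, int64_t end) {
//     for (int64_t m = begin; m < end; m++) {
//       // ... process row m using crow[m] .. crow[m + 1] ...
//     }
//   });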

} // namespace utils

} // namespace at::native