xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/functional_bfloat16.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
4 // See Note [Do not compile initializers with AVX]
5 
6 #include <ATen/cpu/vec/vec.h>
7 
8 namespace at::vec {
9 
10 // BFloat16 specification
11 template <typename scalar_t> struct VecScalarType { using type = scalar_t; };
12 template <> struct VecScalarType<BFloat16> { using type = float; };
13 template <> struct VecScalarType<Half> { using type = float; };
14 
15 // This is different from at::acc_type since we only need to specialize BFloat16
16 template <typename scalar_t>
17 using vec_scalar_t = typename VecScalarType<scalar_t>::type;
18 
19 // Vector conversion between float and bfloat16/half
20 template <typename scalar_t,
21           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
22 inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float(const Vectorized<scalar_t>&);
23 
24 template <>
25 inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<BFloat16> (const Vectorized<BFloat16>& a) {
26   return convert_bfloat16_float(a);
27 }
28 
29 template <>
30 inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<Half> (const Vectorized<Half>& a) {
31     return convert_half_float(a);
32 }
33 
34 template <typename scalar_t,
35           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
36 inline Vectorized<scalar_t> convert_from_float(const Vectorized<float>&, const Vectorized<float>&);
37 
38 template <>
39 inline Vectorized<BFloat16> convert_from_float<BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
40   return convert_float_bfloat16(a, b);
41 }
42 
43 template <>
44 inline Vectorized<Half> convert_from_float<Half>(const Vectorized<float>& a, const Vectorized<float>& b) {
45   return convert_float_half(a, b);
46 }
47 
48 template <typename scalar_t,
49           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
50 inline void load_to_float(const scalar_t *data, Vectorized<float> &out1, Vectorized<float> &out2);
51 
52 template <>
53 inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out1, Vectorized<float> &out2) {
54   load_fp32_from_bf16(data, out1, out2);
55 }
56 
57 template <>
58 inline void load_to_float<Half> (const Half *data, Vectorized<float> &out1, Vectorized<float> &out2) {
59   load_fp32_from_fp16(data, out1, out2);
60 }
61 
62 template <typename scalar_t,
63           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
64 inline void load_to_float(const scalar_t *data, Vectorized<float> &out);
65 
66 template <>
67 inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out) {
68   load_fp32_from_bf16(data, out);
69 }
70 
71 template <>
72 inline void load_to_float<Half> (const Half *data, Vectorized<float> &out) {
73   load_fp32_from_fp16(data, out);
74 }
75 
// Note that Vectorized<scalar_t> already has specialized members for BFloat16,
77 // so the following functions would run smoothly:
78 //   using Vec = Vectorized<BFloat16>;
79 //   Vec one = Vec(BFloat16(1));
80 //   vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
81 //
// Why, then, do we still need to specialize the "functional" helpers?
83 //   If we do specialization at Vectorized<> level, the above example would need 3 pairs of
84 //   conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/".
85 //   If we do specialization at vec::map<>() level, we have only 1 pair of conversion
86 //   of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only.
87 //
// The following BFloat16/Half functions only convert data types for the input
// and output vectors (the reduce functions return their final result as an fp32 scalar).
90 // Compared to Vectorized<> specialization,
91 //   1. better performance since we have less data type conversion;
//   2. less rounding error since intermediate results are kept in fp32;
93 //   3. accumulation done on data type of fp32.
94 //
// If you plan to extend this file, please make sure to add unit tests to
//   aten/src/ATen/test/vec_test_all_types.cpp
97 //
98 template <typename scalar_t, typename Op,
99           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
100 inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
101   using bVec = vec::Vectorized<scalar_t>;
102   using fVec = vec::Vectorized<float>;
103   if (size < bVec::size()) {
104     bVec data_bvec = bVec::loadu(data, size);
105     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
106     if (size > fVec::size()) {
107       data_fvec0 = fVec::set(data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size());
108       return vec_reduce_all<float>(vec_fun, data_fvec0, fVec::size());
109     } else {
110       return vec_reduce_all<float>(vec_fun, data_fvec0, size);
111     }
112   }
113   int64_t d = bVec::size();
114   bVec acc_bvec = bVec::loadu(data);
115   auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
116   for (; d < size - (size % bVec::size()); d += bVec::size()) {
117     bVec data_bvec = bVec::loadu(data + d);
118     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
119     acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
120     acc_fvec1 = vec_fun(acc_fvec1, data_fvec1);
121   }
122   if (size - d > 0) {
123     bVec data_bvec = bVec::loadu(data + d, size - d);
124     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
125     if (size - d > fVec::size()) {
126       acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
127       acc_fvec1 = fVec::set(acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
128     } else {
129       acc_fvec0 = fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d);
130     }
131   }
132   acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1);
133   return vec_reduce_all<float>(vec_fun, acc_fvec0);
134 }
135 
136 template <typename scalar_t, typename Op1, typename Op2,
137           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
138 inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
139     const scalar_t* data, int64_t size) {
140   using bVec = vec::Vectorized<scalar_t>;
141   using fVec = vec::Vectorized<float>;
142   if (size < bVec::size()) {
143     bVec data_bvec = bVec::loadu(data, size);
144     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
145     if (size > fVec::size()) {
146       fVec acc1_fvec = fVec::set(data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size());
147       fVec acc2_fvec = fVec::set(data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size());
148       return std::pair<scalar_t, scalar_t>(
149           vec_reduce_all<float>(vec_fun1, acc1_fvec, fVec::size()),
150           vec_reduce_all<float>(vec_fun2, acc2_fvec, fVec::size()));
151     } else {
152       return std::pair<scalar_t, scalar_t>(
153           vec_reduce_all<float>(vec_fun1, data_fvec0, size),
154           vec_reduce_all<float>(vec_fun2, data_fvec0, size));
155     }
156   }
157   int64_t d = bVec::size();
158   bVec acc_bvec = bVec::loadu(data);
159   auto [acc1_fvec0, acc1_fvec1] = convert_to_float<scalar_t>(acc_bvec);
160   auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc_bvec);
161   for (; d < size - (size % bVec::size()); d += bVec::size()) {
162     bVec data_bvec = bVec::loadu(data + d);
163     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
164     acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
165     acc1_fvec1 = vec_fun1(acc1_fvec1, data_fvec1);
166     acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
167     acc2_fvec1 = vec_fun2(acc2_fvec1, data_fvec1);
168   }
169   if (size - d > 0) {
170     bVec data_bvec = bVec::loadu(data + d, size - d);
171     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
172     if (size - d > fVec::size()) {
173       acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
174       acc1_fvec1 = fVec::set(acc1_fvec1, vec_fun1(acc1_fvec1, data_fvec1), size - d - fVec::size());
175       acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
176       acc2_fvec1 = fVec::set(acc2_fvec1, vec_fun2(acc2_fvec1, data_fvec1), size - d - fVec::size());
177     } else {
178       acc1_fvec0 = fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d);
179       acc2_fvec0 = fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d);
180     }
181   }
182   acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1);
183   acc2_fvec0 = vec_fun2(acc2_fvec0, acc2_fvec1);
184   return std::pair<scalar_t, scalar_t>(
185       vec_reduce_all<float>(vec_fun1, acc1_fvec0),
186       vec_reduce_all<float>(vec_fun2, acc2_fvec0));
187 }
188 
189 template <typename scalar_t, typename MapOp, typename ReduceOp,
190           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
191 inline float map_reduce_all(
192     const MapOp& map_fun,
193     const ReduceOp& red_fun,
194     const scalar_t* data,
195     int64_t size) {
196   using bVec = vec::Vectorized<scalar_t>;
197   using fVec = vec::Vectorized<float>;
198   if (size < bVec::size()) {
199     bVec data_bvec = bVec::loadu(data, size);
200     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
201     if (size > fVec::size()) {
202       data_fvec0 = map_fun(data_fvec0);
203       data_fvec1 = map_fun(data_fvec1);
204       data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
205       return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
206     } else {
207       data_fvec0 = map_fun(data_fvec0);
208       return vec_reduce_all<float>(red_fun, data_fvec0, size);
209     }
210   }
211   int64_t d = bVec::size();
212   bVec acc_bvec = bVec::loadu(data);
213   auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
214   acc_fvec0 = map_fun(acc_fvec0);
215   acc_fvec1 = map_fun(acc_fvec1);
216   for (; d < size - (size % bVec::size()); d += bVec::size()) {
217     bVec data_bvec = bVec::loadu(data + d);
218     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
219     data_fvec0 = map_fun(data_fvec0);
220     data_fvec1 = map_fun(data_fvec1);
221     acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
222     acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
223   }
224   if (size - d > 0) {
225     bVec data_bvec = bVec::loadu(data + d, size - d);
226     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
227     if (size - d > fVec::size()) {
228       data_fvec0 = map_fun(data_fvec0);
229       data_fvec1 = map_fun(data_fvec1);
230       acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
231       acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
232     } else {
233       data_fvec0 = map_fun(data_fvec0);
234       acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
235     }
236   }
237   acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
238   return vec_reduce_all<float>(red_fun, acc_fvec0);
239 }
240 
241 template <typename scalar_t, typename MapOp, typename ReduceOp,
242           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
243 inline float map2_reduce_all(
244     const MapOp& map_fun,
245     const ReduceOp& red_fun,
246     const scalar_t* data,
247     const scalar_t* data2,
248     int64_t size) {
249   using bVec = vec::Vectorized<scalar_t>;
250   using fVec = vec::Vectorized<float>;
251   if (size < bVec::size()) {
252     bVec data_bvec = bVec::loadu(data, size);
253     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
254     bVec data2_bvec = bVec::loadu(data2, size);
255     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
256     if (size > fVec::size()) {
257       data_fvec0 = map_fun(data_fvec0, data2_fvec0);
258       data_fvec1 = map_fun(data_fvec1, data2_fvec1);
259       data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
260       return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
261     } else {
262       data_fvec0 = map_fun(data_fvec0, data2_fvec0);
263       return vec_reduce_all<float>(red_fun, data_fvec0, size);
264     }
265   }
266   int64_t d = bVec::size();
267   bVec acc_bvec = bVec::loadu(data);
268   auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
269   bVec acc2_bvec = bVec::loadu(data2);
270   auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc2_bvec);
271   acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0);
272   acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1);
273   for (; d < size - (size % bVec::size()); d += bVec::size()) {
274     bVec data_bvec = bVec::loadu(data + d);
275     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
276     bVec data2_bvec = bVec::loadu(data2 + d);
277     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
278     data_fvec0 = map_fun(data_fvec0, data2_fvec0);
279     data_fvec1 = map_fun(data_fvec1, data2_fvec1);
280     acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
281     acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
282   }
283   if (size - d > 0) {
284     bVec data_bvec = bVec::loadu(data + d, size - d);
285     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
286     bVec data2_bvec = bVec::loadu(data2 + d, size - d);
287     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
288     if (size - d > fVec::size()) {
289       data_fvec0 = map_fun(data_fvec0, data2_fvec0);
290       data_fvec1 = map_fun(data_fvec1, data2_fvec1);
291       acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
292       acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
293     } else {
294       data_fvec0 = map_fun(data_fvec0, data2_fvec0);
295       acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
296     }
297   }
298   acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
299   return vec_reduce_all<float>(red_fun, acc_fvec0);
300 }
301 
302 template <typename scalar_t, typename MapOp, typename ReduceOp,
303           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
304 inline float map3_reduce_all(
305     const MapOp& map_fun,
306     const ReduceOp& red_fun,
307     const scalar_t* data,
308     const scalar_t* data2,
309     const scalar_t* data3,
310     int64_t size) {
311   using bVec = vec::Vectorized<scalar_t>;
312   using fVec = vec::Vectorized<float>;
313   if (size < bVec::size()) {
314     bVec data_bvec = bVec::loadu(data, size);
315     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
316     bVec data2_bvec = bVec::loadu(data2, size);
317     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
318     bVec data3_bvec = bVec::loadu(data3, size);
319     auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
320     if (size > fVec::size()) {
321       data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
322       data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
323       data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
324       return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
325     } else {
326       data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
327       return vec_reduce_all<float>(red_fun, data_fvec0, size);
328     }
329   }
330   int64_t d = bVec::size();
331   bVec acc_bvec = bVec::loadu(data);
332   auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
333   bVec acc2_bvec = bVec::loadu(data2);
334   auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc2_bvec);
335   bVec acc3_bvec = bVec::loadu(data3);
336   auto [acc3_fvec0, acc3_fvec1] = convert_to_float<scalar_t>(acc3_bvec);
337   acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0, acc3_fvec0);
338   acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1, acc3_fvec1);
339   for (; d < size - (size % bVec::size()); d += bVec::size()) {
340     bVec data_bvec = bVec::loadu(data + d);
341     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
342     bVec data2_bvec = bVec::loadu(data2 + d);
343     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
344     bVec data3_bvec = bVec::loadu(data3 + d);
345     auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
346     data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
347     data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
348     acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
349     acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
350   }
351   if (size - d > 0) {
352     bVec data_bvec = bVec::loadu(data + d, size - d);
353     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
354     bVec data2_bvec = bVec::loadu(data2 + d, size - d);
355     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
356     bVec data3_bvec = bVec::loadu(data3 + d, size - d);
357     auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
358     if (size - d > fVec::size()) {
359       data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
360       data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
361       acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
362       acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
363     } else {
364       data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
365       acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
366     }
367   }
368   acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
369   return vec_reduce_all<float>(red_fun, acc_fvec0);
370 }
371 
372 template <typename scalar_t, typename Op,
373           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
374 inline void map(
375     const Op& vec_fun,
376     scalar_t* output_data,
377     const scalar_t* input_data,
378     int64_t size) {
379   using bVec = vec::Vectorized<scalar_t>;
380   using fVec = vec::Vectorized<float>;
381   int64_t d = 0;
382   for (; d < size - (size % bVec::size()); d += bVec::size()) {
383     bVec data_bvec = bVec::loadu(input_data + d);
384     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
385     fVec output_fvec0 = vec_fun(data_fvec0);
386     fVec output_fvec1 = vec_fun(data_fvec1);
387     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
388     output_bvec.store(output_data + d);
389   }
390   if (size - d > 0) {
391     bVec data_bvec = bVec::loadu(input_data + d, size - d);
392     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
393     fVec output_fvec0 = vec_fun(data_fvec0);
394     fVec output_fvec1 = vec_fun(data_fvec1);
395     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
396     output_bvec.store(output_data + d, size - d);
397   }
398 }
399 
400 template <typename scalar_t, typename Op,
401           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
402 inline void map(
403     const Op& vec_fun,
404     scalar_t* output_data,
405     const float* input_data,
406     int64_t size) {
407   using bVec = vec::Vectorized<scalar_t>;
408   using fVec = vec::Vectorized<float>;
409   int64_t d = 0;
410   for (; d < size - (size % bVec::size()); d += bVec::size()) {
411     fVec data_fvec0 = fVec::loadu(input_data + d);
412     fVec data_fvec1 = fVec::loadu(input_data + d + fVec::size());
413     fVec output_fvec0 = vec_fun(data_fvec0);
414     fVec output_fvec1 = vec_fun(data_fvec1);
415     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
416     output_bvec.store(output_data + d);
417   }
418   if (size - d > 0) {
419     fVec data_fvec0, data_fvec1;
420     if (size - d > fVec::size()) {
421       data_fvec0 = fVec::loadu(input_data + d);
422       data_fvec1 = fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size());
423     } else {
424       // choose to align with behaviour of bVec::loadu(ptr, size),
425       // which leaves data_fvec1 uninitialized
426       data_fvec0 = fVec::loadu(input_data + d, size - d);
427     }
428     fVec output_fvec0 = vec_fun(data_fvec0);
429     fVec output_fvec1 = vec_fun(data_fvec1);
430     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
431     output_bvec.store(output_data + d, size - d);
432   }
433 }
434 
435 template <typename scalar_t, typename Op,
436           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
437 inline void map2(
438     const Op& vec_fun,
439     scalar_t* output_data,
440     const scalar_t* input_data,
441     const scalar_t* input_data2,
442     int64_t size) {
443   using bVec = vec::Vectorized<scalar_t>;
444   using fVec = vec::Vectorized<float>;
445   int64_t d = 0;
446   for (; d < size - (size % bVec::size()); d += bVec::size()) {
447     bVec data_bvec = bVec::loadu(input_data + d);
448     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
449     bVec data2_bvec = bVec::loadu(input_data2 + d);
450     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
451     fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
452     fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
453     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
454     output_bvec.store(output_data + d);
455   }
456   if (size - d > 0) {
457     bVec data_bvec = bVec::loadu(input_data + d, size - d);
458     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
459     bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
460     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
461     fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
462     fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
463     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
464     output_bvec.store(output_data + d, size - d);
465   }
466 }
467 
468 template <typename scalar_t, typename Op,
469           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
470 inline void map3(
471     const Op& vec_fun,
472     scalar_t* output_data,
473     const scalar_t* input_data1,
474     const scalar_t* input_data2,
475     const scalar_t* input_data3,
476     int64_t size) {
477   using bVec = vec::Vectorized<scalar_t>;
478   using fVec = vec::Vectorized<float>;
479   int64_t d = 0;
480   for (; d < size - (size % bVec::size()); d += bVec::size()) {
481     bVec data1_bvec = bVec::loadu(input_data1 + d);
482     auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
483     bVec data2_bvec = bVec::loadu(input_data2 + d);
484     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
485     bVec data3_bvec = bVec::loadu(input_data3 + d);
486     auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
487     fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
488     fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
489     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
490     output_bvec.store(output_data + d);
491   }
492   if (size - d > 0) {
493     bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
494     auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
495     bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
496     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
497     bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
498     auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
499     fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
500     fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
501     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
502     output_bvec.store(output_data + d, size - d);
503   }
504 }
505 
506 template <typename scalar_t, typename Op,
507           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
508 inline void map4(
509     const Op& vec_fun,
510     scalar_t* output_data,
511     const scalar_t* input_data1,
512     const scalar_t* input_data2,
513     const scalar_t* input_data3,
514     const scalar_t* input_data4,
515     int64_t size) {
516   using bVec = vec::Vectorized<scalar_t>;
517   using fVec = vec::Vectorized<float>;
518   int64_t d = 0;
519   for (; d < size - (size % bVec::size()); d += bVec::size()) {
520     bVec data1_bvec = bVec::loadu(input_data1 + d);
521     auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
522     bVec data2_bvec = bVec::loadu(input_data2 + d);
523     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
524     bVec data3_bvec = bVec::loadu(input_data3 + d);
525     auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
526     bVec data4_bvec = bVec::loadu(input_data4 + d);
527     auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
528     fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
529     fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
530     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
531     output_bvec.store(output_data + d);
532   }
533   if (size - d > 0) {
534     bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
535     auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
536     bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
537     auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
538     bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
539     auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
540     bVec data4_bvec = bVec::loadu(input_data4 + d, size - d);
541     auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
542     fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
543     fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
544     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
545     output_bvec.store(output_data + d, size - d);
546   }
547 }
548 
549 } // namespace at::vec
550