xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/LerpKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/Lerp.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/TensorIterator.h>
5 #include <ATen/native/cpu/Loops.h>
6 
7 #include <c10/util/irange.h>
8 
9 namespace at {
10 namespace native {
11 namespace {
12 
13 template <typename scalar_t>
is_lerp_weight_small(Vectorized<scalar_t> weight)14 Vectorized<scalar_t> is_lerp_weight_small(Vectorized<scalar_t> weight) {
15   static_assert(!c10::is_complex<scalar_t>::value, "");
16   return weight.abs() < Vectorized<scalar_t>(0.5);
17 }
18 
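// is_lerp_weight_small doesn't work for complex types because z.abs() returns
// a complex vector, which can't be compared. Either implement it with
// z.abs_2_(), or fall back to the scalar function. The #if branch below does
// the former; the #else branch (default CPU capability or MSVC) does the
// latter by running the scalar lerp lane by lane.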
22 #if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER))
23 template <typename value_t>
is_lerp_weight_small(Vectorized<c10::complex<value_t>> weight)24 Vectorized<c10::complex<value_t>> is_lerp_weight_small(Vectorized<c10::complex<value_t>> weight) {
25   using vec_reg_t = decltype(weight.abs_2_());
26   vec_reg_t mask = Vectorized<value_t>(weight.abs_2_()) < Vectorized<value_t>(0.25);
27   return Vectorized<c10::complex<value_t>>(mask);
28 }
29 #else
30 template <typename scalar_t>
lerp_vec_map(Vectorized<scalar_t> start,Vectorized<scalar_t> end,Vectorized<scalar_t> weight)31 Vectorized<scalar_t> lerp_vec_map(Vectorized<scalar_t> start, Vectorized<scalar_t> end, Vectorized<scalar_t> weight) {
32   using vec_t = Vectorized<scalar_t>;
33   __at_align__ scalar_t start_arr[vec_t::size()];
34   __at_align__ scalar_t end_arr[vec_t::size()];
35   __at_align__ scalar_t weight_arr[vec_t::size()];
36   __at_align__ scalar_t result_arr[vec_t::size()];
37 
38   start.store(start_arr);
39   end.store(end_arr);
40   weight.store(weight_arr);
41 
42   for (auto i : c10::irange(vec_t::size())) {
43     result_arr[i] = lerp(start_arr[i], end_arr[i], weight_arr[i]);
44   }
45   return vec_t::loadu(result_arr);
46 }
47 
48 template <typename value_t>
lerp_vec(Vectorized<c10::complex<value_t>> start,Vectorized<c10::complex<value_t>> end,Vectorized<c10::complex<value_t>> weight)49 Vectorized<c10::complex<value_t>> lerp_vec(Vectorized<c10::complex<value_t>> start, Vectorized<c10::complex<value_t>> end, Vectorized<c10::complex<value_t>> weight) {
50   return lerp_vec_map(start, end, weight);
51 }
52 #endif
53 
54 template <typename scalar_t>
lerp_vec(Vectorized<scalar_t> start,Vectorized<scalar_t> end,Vectorized<scalar_t> weight)55 Vectorized<scalar_t> lerp_vec(Vectorized<scalar_t> start, Vectorized<scalar_t> end, Vectorized<scalar_t> weight) {
56   using vec_t = Vectorized<scalar_t>;
57   auto mask = is_lerp_weight_small(weight);
58   auto coeff = vec_t::blendv(weight - vec_t(1), weight, mask);
59   auto base = vec_t::blendv(end, start, mask);
60   return vec::fmadd(coeff, end - start, base);
61 }
62 
lerp_scalar_kernel(at::TensorIteratorBase & iter,const Scalar & weight)63 void lerp_scalar_kernel(at::TensorIteratorBase& iter, const Scalar& weight) {
64   if (iter.common_dtype() == kBFloat16) {
65     using bVec = Vectorized<BFloat16>;
66     using fVec = Vectorized<float>;
67     float weight_val = weight.to<float>();
68     auto weight_vec = fVec(weight_val);
69     at::native::cpu_kernel_vec(
70       iter,
71       [weight_val](BFloat16 self_val, BFloat16 end_val) -> BFloat16 {
72         return lerp(self_val, end_val, weight_val);
73       },
74       [=](bVec self_vec, bVec end_vec) -> bVec {
75           auto [self_vec0, self_vec1] = convert_bfloat16_float(self_vec);
76           auto [end_vec0, end_vec1] = convert_bfloat16_float(end_vec);
77           auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec);
78           auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
79           return convert_float_bfloat16(result0, result1);
80       });
81   } else if (iter.common_dtype() == kHalf) {
82     using hVec = Vectorized<Half>;
83     using fVec = Vectorized<float>;
84     float weight_val = weight.to<float>();
85     auto weight_vec = fVec(weight_val);
86     at::native::cpu_kernel_vec(
87       iter,
88       [weight_val](Half self_val, Half end_val) -> Half {
89         return lerp(self_val, end_val, weight_val);
90       },
91       [=](hVec self_vec, hVec end_vec) -> hVec {
92           auto [self_vec0, self_vec1] = convert_half_float(self_vec);
93           auto [end_vec0, end_vec1] = convert_half_float(end_vec);
94           auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec);
95           auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
96           return convert_float_half(result0, result1);
97       });
98   } else {
99     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_scalar", [&] {
100       auto weight_val = weight.to<scalar_t>();
101       at::native::cpu_kernel_vec(
102           iter,
103           [weight_val](scalar_t self_val, scalar_t end_val) {
104             return lerp(self_val, end_val, weight_val);
105           },
106           [weight_val](Vectorized<scalar_t> self, Vectorized<scalar_t> end) {
107             const Vectorized<scalar_t> weight(weight_val);
108             return lerp_vec(self, end, weight);
109           });
110     });
111   }
112 }
113 
lerp_tensor_kernel(at::TensorIteratorBase & iter)114 void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
115   if (iter.common_dtype() == kBFloat16) {
116     using bVec = Vectorized<BFloat16>;
117     at::native::cpu_kernel_vec(
118       iter,
119       [=](BFloat16 self_val, BFloat16 end_val, BFloat16 weight_val) -> BFloat16 {
120         return lerp(self_val, end_val, weight_val);
121       },
122       [=](bVec self_vec, bVec end_vec, bVec weight_vec) -> bVec {
123           auto [self_vec0, self_vec1] = convert_bfloat16_float(self_vec);
124           auto [end_vec0, end_vec1] = convert_bfloat16_float(end_vec);
125           auto [weight_vec0, weight_vec1] = convert_bfloat16_float(weight_vec);
126           auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec0);
127           auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
128           return convert_float_bfloat16(result0, result1);
129       });
130   } else if (iter.common_dtype() == kHalf) {
131     using hVec = Vectorized<Half>;
132     at::native::cpu_kernel_vec(
133       iter,
134       [=](Half self_val, Half end_val, Half weight_val) -> Half {
135         return lerp(self_val, end_val, weight_val);
136       },
137       [=](hVec self_vec, hVec end_vec, hVec weight_vec) -> hVec {
138           auto [self_vec0, self_vec1] = convert_half_float(self_vec);
139           auto [end_vec0, end_vec1] = convert_half_float(end_vec);
140           auto [weight_vec0, weight_vec1] = convert_half_float(weight_vec);
141           auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec0);
142           auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
143           return convert_float_half(result0, result1);
144       });
145   } else {
146     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_tensor", [&] {
147       at::native::cpu_kernel_vec(
148           iter,
149           [](scalar_t self_val, scalar_t end_val, scalar_t weight_val) {
150             return lerp(self_val, end_val, weight_val);
151           },
152           [](Vectorized<scalar_t> self_val, Vectorized<scalar_t> end_val, Vectorized<scalar_t> weight_val) {
153             return lerp_vec(self_val, end_val, weight_val);
154           });
155     });
156   }
157 }
158 
159 } // anonymous namespace
160 
161 REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_scalar_kernel);
162 REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_tensor_kernel);
163 
164 } // namespace native
165 } // namespace at
166