1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/Lerp.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/TensorIterator.h>
5 #include <ATen/native/cpu/Loops.h>
6
7 #include <c10/util/irange.h>
8
9 namespace at {
10 namespace native {
11 namespace {
12
13 template <typename scalar_t>
is_lerp_weight_small(Vectorized<scalar_t> weight)14 Vectorized<scalar_t> is_lerp_weight_small(Vectorized<scalar_t> weight) {
15 static_assert(!c10::is_complex<scalar_t>::value, "");
16 return weight.abs() < Vectorized<scalar_t>(0.5);
17 }
18
// The real-valued is_lerp_weight_small doesn't work for complex inputs because
// z.abs() returns a complex vector, which can't be compared. Either implement
// it with z.abs_2_(), or fall back to the scalar function.
22 #if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER))
23 template <typename value_t>
is_lerp_weight_small(Vectorized<c10::complex<value_t>> weight)24 Vectorized<c10::complex<value_t>> is_lerp_weight_small(Vectorized<c10::complex<value_t>> weight) {
25 using vec_reg_t = decltype(weight.abs_2_());
26 vec_reg_t mask = Vectorized<value_t>(weight.abs_2_()) < Vectorized<value_t>(0.25);
27 return Vectorized<c10::complex<value_t>>(mask);
28 }
29 #else
30 template <typename scalar_t>
lerp_vec_map(Vectorized<scalar_t> start,Vectorized<scalar_t> end,Vectorized<scalar_t> weight)31 Vectorized<scalar_t> lerp_vec_map(Vectorized<scalar_t> start, Vectorized<scalar_t> end, Vectorized<scalar_t> weight) {
32 using vec_t = Vectorized<scalar_t>;
33 __at_align__ scalar_t start_arr[vec_t::size()];
34 __at_align__ scalar_t end_arr[vec_t::size()];
35 __at_align__ scalar_t weight_arr[vec_t::size()];
36 __at_align__ scalar_t result_arr[vec_t::size()];
37
38 start.store(start_arr);
39 end.store(end_arr);
40 weight.store(weight_arr);
41
42 for (auto i : c10::irange(vec_t::size())) {
43 result_arr[i] = lerp(start_arr[i], end_arr[i], weight_arr[i]);
44 }
45 return vec_t::loadu(result_arr);
46 }
47
48 template <typename value_t>
lerp_vec(Vectorized<c10::complex<value_t>> start,Vectorized<c10::complex<value_t>> end,Vectorized<c10::complex<value_t>> weight)49 Vectorized<c10::complex<value_t>> lerp_vec(Vectorized<c10::complex<value_t>> start, Vectorized<c10::complex<value_t>> end, Vectorized<c10::complex<value_t>> weight) {
50 return lerp_vec_map(start, end, weight);
51 }
52 #endif
53
54 template <typename scalar_t>
lerp_vec(Vectorized<scalar_t> start,Vectorized<scalar_t> end,Vectorized<scalar_t> weight)55 Vectorized<scalar_t> lerp_vec(Vectorized<scalar_t> start, Vectorized<scalar_t> end, Vectorized<scalar_t> weight) {
56 using vec_t = Vectorized<scalar_t>;
57 auto mask = is_lerp_weight_small(weight);
58 auto coeff = vec_t::blendv(weight - vec_t(1), weight, mask);
59 auto base = vec_t::blendv(end, start, mask);
60 return vec::fmadd(coeff, end - start, base);
61 }
62
lerp_scalar_kernel(at::TensorIteratorBase & iter,const Scalar & weight)63 void lerp_scalar_kernel(at::TensorIteratorBase& iter, const Scalar& weight) {
64 if (iter.common_dtype() == kBFloat16) {
65 using bVec = Vectorized<BFloat16>;
66 using fVec = Vectorized<float>;
67 float weight_val = weight.to<float>();
68 auto weight_vec = fVec(weight_val);
69 at::native::cpu_kernel_vec(
70 iter,
71 [weight_val](BFloat16 self_val, BFloat16 end_val) -> BFloat16 {
72 return lerp(self_val, end_val, weight_val);
73 },
74 [=](bVec self_vec, bVec end_vec) -> bVec {
75 auto [self_vec0, self_vec1] = convert_bfloat16_float(self_vec);
76 auto [end_vec0, end_vec1] = convert_bfloat16_float(end_vec);
77 auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec);
78 auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
79 return convert_float_bfloat16(result0, result1);
80 });
81 } else if (iter.common_dtype() == kHalf) {
82 using hVec = Vectorized<Half>;
83 using fVec = Vectorized<float>;
84 float weight_val = weight.to<float>();
85 auto weight_vec = fVec(weight_val);
86 at::native::cpu_kernel_vec(
87 iter,
88 [weight_val](Half self_val, Half end_val) -> Half {
89 return lerp(self_val, end_val, weight_val);
90 },
91 [=](hVec self_vec, hVec end_vec) -> hVec {
92 auto [self_vec0, self_vec1] = convert_half_float(self_vec);
93 auto [end_vec0, end_vec1] = convert_half_float(end_vec);
94 auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec);
95 auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
96 return convert_float_half(result0, result1);
97 });
98 } else {
99 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_scalar", [&] {
100 auto weight_val = weight.to<scalar_t>();
101 at::native::cpu_kernel_vec(
102 iter,
103 [weight_val](scalar_t self_val, scalar_t end_val) {
104 return lerp(self_val, end_val, weight_val);
105 },
106 [weight_val](Vectorized<scalar_t> self, Vectorized<scalar_t> end) {
107 const Vectorized<scalar_t> weight(weight_val);
108 return lerp_vec(self, end, weight);
109 });
110 });
111 }
112 }
113
lerp_tensor_kernel(at::TensorIteratorBase & iter)114 void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
115 if (iter.common_dtype() == kBFloat16) {
116 using bVec = Vectorized<BFloat16>;
117 at::native::cpu_kernel_vec(
118 iter,
119 [=](BFloat16 self_val, BFloat16 end_val, BFloat16 weight_val) -> BFloat16 {
120 return lerp(self_val, end_val, weight_val);
121 },
122 [=](bVec self_vec, bVec end_vec, bVec weight_vec) -> bVec {
123 auto [self_vec0, self_vec1] = convert_bfloat16_float(self_vec);
124 auto [end_vec0, end_vec1] = convert_bfloat16_float(end_vec);
125 auto [weight_vec0, weight_vec1] = convert_bfloat16_float(weight_vec);
126 auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec0);
127 auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
128 return convert_float_bfloat16(result0, result1);
129 });
130 } else if (iter.common_dtype() == kHalf) {
131 using hVec = Vectorized<Half>;
132 at::native::cpu_kernel_vec(
133 iter,
134 [=](Half self_val, Half end_val, Half weight_val) -> Half {
135 return lerp(self_val, end_val, weight_val);
136 },
137 [=](hVec self_vec, hVec end_vec, hVec weight_vec) -> hVec {
138 auto [self_vec0, self_vec1] = convert_half_float(self_vec);
139 auto [end_vec0, end_vec1] = convert_half_float(end_vec);
140 auto [weight_vec0, weight_vec1] = convert_half_float(weight_vec);
141 auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec0);
142 auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
143 return convert_float_half(result0, result1);
144 });
145 } else {
146 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_tensor", [&] {
147 at::native::cpu_kernel_vec(
148 iter,
149 [](scalar_t self_val, scalar_t end_val, scalar_t weight_val) {
150 return lerp(self_val, end_val, weight_val);
151 },
152 [](Vectorized<scalar_t> self_val, Vectorized<scalar_t> end_val, Vectorized<scalar_t> weight_val) {
153 return lerp_vec(self_val, end_val, weight_val);
154 });
155 });
156 }
157 }
158
159 } // anonymous namespace
160
// Hook the two kernels defined in this file up to the lerp dispatch stubs:
// the Scalar-weight kernel and the tensor-weight kernel respectively.
// NOTE(review): lerp_kernel_scalar_weight / lerp_kernel_tensor_weight are not
// declared in this file; presumably they come from ATen/native/Lerp.h -- confirm.
161 REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_scalar_kernel);
162 REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_tensor_kernel);
163
164 } // namespace native
165 } // namespace at
166