xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Lerp.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/native/DispatchStub.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <c10/core/Scalar.h>

namespace at::native {

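// Helpers that pick between two algebraically equivalent lerp formulations
// based on the magnitude of the weight (see lerp() below). The real overload
// compares |weight| against 0.5; the complex overload compares the squared
// magnitude against 0.25 so the square root in abs() can be skipped.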
template <typename scalar_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) {
  return std::abs(weight) < scalar_t(0.5);
}
template <typename scalar_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex<scalar_t> weight) {
  // Avoid the sqrt in abs(weight)
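  // by comparing the squared magnitude instead: |w|^2 < 0.25 iff |w| < 0.5.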
  return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25);
}

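// Linear interpolation of self and end with the given weight. Inputs are
// promoted to at::opmath_type before computing (e.g. float math for Half and
// BFloat16 inputs). One of two algebraically equivalent forms is selected
// based on the weight magnitude for better numerical behavior; see the PR
// linked in the comment below.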
template <typename scalar_t, typename weight_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) {
  using opmath_t = at::opmath_type<scalar_t>;
  using opmath_weight_t = at::opmath_type<weight_t>;

  opmath_t self = self_;
  opmath_t end = end_;
  opmath_weight_t weight = weight_;

  // Conditional for better numerical accuracy. This has been discussed in
  // https://github.com/pytorch/pytorch/pull/18871
  return is_lerp_weight_small(weight)
      ? self + weight * (end - self)
      : end - (end - self) * (opmath_t(1) - weight);
}
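// Illustrative example (not part of the original header): with float inputs,
// lerp(1.0f, 5.0f, 0.25f) takes the small-weight branch and returns
// 1.0f + 0.25f * (5.0f - 1.0f) == 2.0f, while lerp(1.0f, 5.0f, 0.75f)
// takes the other branch and returns 5.0f - (5.0f - 1.0f) * 0.25f == 4.0f.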

using lerp_fn_scalar = void (*)(
    at::TensorIteratorBase& iter,
    const Scalar& weight);

using lerp_fn_tensor = void (*)(
    at::TensorIteratorBase& iter);

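// Dispatch-stub declarations for the lerp kernels. The scalar-weight variant
// takes the weight as a Scalar alongside the TensorIterator; the tensor-weight
// variant expects the weight to be supplied as an additional input of the
// iterator. Per-device implementations are registered elsewhere (e.g. via
// REGISTER_DISPATCH in the corresponding CPU/CUDA kernel files).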
DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight);
DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight);

} // namespace at::native