1 #pragma once 2 3 #include <ATen/native/DispatchStub.h> 4 #include <ATen/OpMathType.h> 5 #include <ATen/TensorIterator.h> 6 #include <c10/core/Scalar.h> 7 8 namespace at::native { 9 10 template <typename scalar_t> is_lerp_weight_small(scalar_t weight)11C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) { 12 return std::abs(weight) < scalar_t(0.5); 13 } 14 template <typename scalar_t> is_lerp_weight_small(c10::complex<scalar_t> weight)15C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex<scalar_t> weight) { 16 // Avoid the sqrt in abs(weight) 17 return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25); 18 } 19 20 template <typename scalar_t, typename weight_t> lerp(scalar_t self_,scalar_t end_,weight_t weight_)21C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) { 22 using opmath_t = at::opmath_type<scalar_t>; 23 using opmath_weight_t = at::opmath_type<weight_t>; 24 25 opmath_t self = self_; 26 opmath_t end = end_; 27 opmath_weight_t weight = weight_; 28 29 // Conditional for better numeric. This has been discussed in 30 // https://github.com/pytorch/pytorch/pull/18871 31 return is_lerp_weight_small(weight) 32 ? self + weight * (end - self) 33 : end - (end - self) * (opmath_t(1) - weight); 34 } 35 36 using lerp_fn_scalar = void (*)( 37 at::TensorIteratorBase& iter, 38 const Scalar& weight); 39 40 using lerp_fn_tensor = void (*)( 41 at::TensorIteratorBase& iter); 42 43 DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight); 44 DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight); 45 46 } // namespace at::native 47