xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Lerp.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/Lerp.h>
3 #include <ATen/core/Tensor.h>
4 #include <ATen/TensorIterator.h>
5 #include <ATen/TensorMeta.h>
6 
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/lerp_native.h>
11 #endif
12 
13 namespace at::meta {
14 
TORCH_META_FUNC(lerp_Tensor)15 TORCH_META_FUNC(lerp_Tensor)(
16     const Tensor& self, const Tensor& end, const Tensor& weight) {
17   TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(),
18               " for `end` but got dtype ", end.dtype());
19   TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(),
20               " for `weight` but got dtype ", weight.dtype());
21   build(at::TensorIteratorConfig()
22         .add_output(maybe_get_output())
23         .add_const_input(self)
24         .add_const_input(end)
25         .add_const_input(weight));
26 }
27 
TORCH_META_FUNC(lerp_Scalar)28 TORCH_META_FUNC(lerp_Scalar)(
29     const Tensor& self, const Tensor& end, const Scalar& /*weight*/) {
30   TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(),
31               " for `end` but got dtype ", end.dtype());
32   build_binary_op(maybe_get_output(), self, end);
33 }
34 
35 }  // namespace at::meta
36 
37 namespace at::native {
38 
TORCH_IMPL_FUNC(lerp_Tensor)39 TORCH_IMPL_FUNC(lerp_Tensor)(
40     const Tensor& /*self*/, const Tensor& /*end*/, const Tensor& weight, const Tensor& /*out*/) {
41   lerp_kernel_tensor_weight(device_type(), *this);
42 }
43 
TORCH_IMPL_FUNC(lerp_Scalar)44 TORCH_IMPL_FUNC(lerp_Scalar)(
45     const Tensor& /*self*/, const Tensor& /*end*/, const Scalar& weight, const Tensor& /*out*/) {
46   lerp_kernel_scalar_weight(device_type(), *this, weight);
47 }
48 
49 DEFINE_DISPATCH(lerp_kernel_scalar_weight);
50 DEFINE_DISPATCH(lerp_kernel_tensor_weight);
51 
52 } // namespace at::native
53