1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/Lerp.h>
3 #include <ATen/core/Tensor.h>
4 #include <ATen/TensorIterator.h>
5 #include <ATen/TensorMeta.h>
6
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/lerp_native.h>
11 #endif
12
13 namespace at::meta {
14
TORCH_META_FUNC(lerp_Tensor)(
    const Tensor& self, const Tensor& end, const Tensor& weight) {
  // All three tensor operands must agree on dtype; report whichever differs.
  TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(),
              " for `end` but got dtype ", end.dtype());
  TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(),
              " for `weight` but got dtype ", weight.dtype());
  // Ternary iterator: one output, three read-only inputs. The output slot is
  // the pre-allocated result when called via the out= overload, empty otherwise.
  auto config = at::TensorIteratorConfig();
  config.add_output(maybe_get_output())
      .add_const_input(self)
      .add_const_input(end)
      .add_const_input(weight);
  build(config);
}
27
TORCH_META_FUNC(lerp_Scalar)(
    const Tensor& self, const Tensor& end, const Scalar& /*weight*/) {
  // Only `self` and `end` carry a tensor dtype here; the scalar weight is
  // converted by the kernel and needs no dtype check.
  const auto expected = self.dtype();
  TORCH_CHECK(expected == end.dtype(), "expected dtype ", expected,
              " for `end` but got dtype ", end.dtype());
  build_binary_op(maybe_get_output(), self, end);
}
34
35 } // namespace at::meta
36
37 namespace at::native {
38
TORCH_IMPL_FUNC(lerp_Tensor)(
    const Tensor& /*self*/, const Tensor& /*end*/, const Tensor& /*weight*/, const Tensor& /*out*/) {
  // All operands (including weight) were packed into the structured iterator
  // (*this) by the meta function, so the parameters are unused here; `weight`
  // is commented out like its siblings to avoid an unused-parameter warning.
  lerp_kernel_tensor_weight(device_type(), *this);
}
43
TORCH_IMPL_FUNC(lerp_Scalar)(
    const Tensor& /*self*/, const Tensor& /*end*/, const Scalar& weight, const Tensor& /*out*/) {
  // The iterator (*this) already holds self/end/out from the meta function;
  // only the scalar weight must be forwarded to the device-specific kernel.
  const auto dev = device_type();
  lerp_kernel_scalar_weight(dev, *this, weight);
}
48
// Define the dispatch stubs (declared in ATen/native/Lerp.h) that route the
// structured implementations above to a device-specific kernel; backend
// kernels presumably register into these tables elsewhere (REGISTER_DISPATCH).
DEFINE_DISPATCH(lerp_kernel_scalar_weight);
DEFINE_DISPATCH(lerp_kernel_tensor_weight);
51
52 } // namespace at::native
53