xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ScalarOps.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/Dispatch_v2.h>
4 #include <ATen/EmptyTensor.h>
5 #include <ATen/ScalarOps.h>
6 
7 namespace at {
8 namespace {
9 template <typename scalar_t>
fill_inplace(Tensor & self,const Scalar & value_scalar)10 inline void fill_inplace(Tensor& self, const Scalar& value_scalar) {
11   auto value = value_scalar.to<scalar_t>();
12   scalar_t* dptr = static_cast<scalar_t*>(self.data_ptr());
13   *dptr = value;
14 }
15 }
16 
17 namespace detail {
scalar_fill(Tensor & self,const Scalar & value)18 Tensor& scalar_fill(Tensor& self, const Scalar& value) {
19   AT_DISPATCH_V2(
20       self.scalar_type(), "fill_out", AT_WRAP([&]() {
21         fill_inplace<scalar_t>(self, value);
22       }), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
23   return self;
24 }
25 
scalar_tensor_static(const Scalar & s,std::optional<ScalarType> dtype_opt,std::optional<Device> device_opt)26 Tensor scalar_tensor_static(const Scalar& s, std::optional<ScalarType> dtype_opt, std::optional<Device> device_opt) {
27   at::tracer::impl::NoTracerDispatchMode tracer_guard;
28   at::AutoDispatchBelowAutograd mode;
29   Tensor result = at::detail::empty_cpu(
30       {}, dtype_opt, std::nullopt, device_opt, std::nullopt, std::nullopt);
31   scalar_fill(result, s);
32   return result;
33 }
34 } // namespace detail
35 } // namespace at
36