#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Scalar.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at::detail {
// When filling a number into a 1-element CPU tensor, we want to skip
// everything else and manipulate the data ptr directly.
// Ideally this fast path should be implemented in TensorIterator,
// but we also want to skip compute_types, which is not avoidable
// in TensorIterator for now.
Tensor& scalar_fill(Tensor& self, const Scalar& value);
TORCH_API Tensor scalar_tensor_static(
    const Scalar& s,
    std::optional<ScalarType> dtype_opt,
    std::optional<Device> device_opt);
} // namespace at::detail
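// Usage sketch (illustrative only; example_scalar_fast_path is a hypothetical
// helper, not part of this header): scalar_tensor_static builds a CPU tensor
// holding the scalar, and scalar_fill then overwrites that single element
// through the data pointer, as described in the comment above.
inline at::Tensor example_scalar_fast_path() {
  at::Tensor t = at::detail::scalar_tensor_static(
      c10::Scalar(1.5), /*dtype_opt=*/at::kDouble, /*device_opt=*/at::kCPU);
  // Refill the existing 1-element CPU tensor in place.
  at::detail::scalar_fill(t, c10::Scalar(2.5));
  return t;
}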

// This is in the c10 namespace because we use ADL to find the functions in it.
namespace c10 {

// FIXME: this should be (and was) Scalar::toTensor, but there is currently no
// way to implement this without going through Derived Types (which are not
// part of core).
inline at::Tensor scalar_to_tensor(
    const Scalar& s,
    const Device device = at::kCPU) {
  // This is the fast track we have for CPU scalar tensors.
  if (device == at::kCPU) {
    return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
  }
  return at::scalar_tensor(s, at::device(device).dtype(s.type()));
}

} // namespace c10
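// Usage sketch (illustrative only; example_scalar_to_tensor is a hypothetical
// helper, not part of this header). With the default kCPU device,
// scalar_to_tensor takes the scalar_tensor_static fast track above; any other
// device falls back to at::scalar_tensor with the scalar's own dtype.
inline at::Tensor example_scalar_to_tensor(const c10::Scalar& s) {
  // Default device is kCPU, so this uses the fast track.
  at::Tensor on_cpu = c10::scalar_to_tensor(s);
  // A non-CPU device would instead go through at::scalar_tensor, e.g.:
  //   at::Tensor on_cuda = c10::scalar_to_tensor(s, at::kCUDA);
  return on_cpu;
}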

namespace at::native {

inline Tensor wrapped_scalar_tensor(
    const Scalar& scalar,
    const Device device = at::kCPU) {
  auto tensor = scalar_to_tensor(scalar, device);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}

} // namespace at::native
54