xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ops/tensor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <c10/core/ScalarType.h>
4 
5 namespace at {
6 
7 // These functions are defined in ATen/Utils.cpp.
8 #define TENSOR(T, S)                                                          \
9   TORCH_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options);  \
10   inline Tensor tensor(                                                       \
11       std::initializer_list<T> values, const TensorOptions& options) {        \
12     return at::tensor(ArrayRef<T>(values), options);                          \
13   }                                                                           \
14   inline Tensor tensor(T value, const TensorOptions& options) {               \
15     return at::tensor(ArrayRef<T>(value), options);                           \
16   }                                                                           \
17   inline Tensor tensor(ArrayRef<T> values) {                                  \
18     return at::tensor(std::move(values), at::dtype(k##S));                    \
19   }                                                                           \
20   inline Tensor tensor(std::initializer_list<T> values) {                     \
21     return at::tensor(ArrayRef<T>(values));                                   \
22   }                                                                           \
23   inline Tensor tensor(T value) {                                             \
24     return at::tensor(ArrayRef<T>(value));                                    \
25   }
26 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
27 AT_FORALL_COMPLEX_TYPES(TENSOR)
28 #undef TENSOR
29 
30 }  // namespace at
31