1 #pragma once 2 #include <ATen/core/Tensor.h> 3 #include <c10/core/ScalarType.h> 4 5 namespace at { 6 7 // These functions are defined in ATen/Utils.cpp. 8 #define TENSOR(T, S) \ 9 TORCH_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \ 10 inline Tensor tensor( \ 11 std::initializer_list<T> values, const TensorOptions& options) { \ 12 return at::tensor(ArrayRef<T>(values), options); \ 13 } \ 14 inline Tensor tensor(T value, const TensorOptions& options) { \ 15 return at::tensor(ArrayRef<T>(value), options); \ 16 } \ 17 inline Tensor tensor(ArrayRef<T> values) { \ 18 return at::tensor(std::move(values), at::dtype(k##S)); \ 19 } \ 20 inline Tensor tensor(std::initializer_list<T> values) { \ 21 return at::tensor(ArrayRef<T>(values)); \ 22 } \ 23 inline Tensor tensor(T value) { \ 24 return at::tensor(ArrayRef<T>(value)); \ 25 } 26 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) 27 AT_FORALL_COMPLEX_TYPES(TENSOR) 28 #undef TENSOR 29 30 } // namespace at 31