// xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <stdexcept>

#include <torch/csrc/inductor/aoti_runtime/utils.h>
4 
5 namespace torch::aot_inductor {
6 
7 template <typename T>
scalar_to_tensor_handle(T value)8 inline RAIIAtenTensorHandle scalar_to_tensor_handle(T value) {
9   throw std::runtime_error("Unsupported scalar_to_tensor_handle");
10 }
11 
12 // Specialize for supported C++ primitive types
13 #define AOTI_RUNTIME_SCALAR_TO_TENSOR(dtype, ctype)                         \
14   template <>                                                               \
15   inline RAIIAtenTensorHandle scalar_to_tensor_handle<ctype>(ctype value) { \
16     AtenTensorHandle tensor_handle;                                         \
17     AOTI_TORCH_ERROR_CODE_CHECK(                                            \
18         aoti_torch_scalar_to_tensor_##dtype(value, &tensor_handle));        \
19     return RAIIAtenTensorHandle(tensor_handle);                             \
20   }
21 
22 AOTI_RUNTIME_SCALAR_TO_TENSOR(float32, float)
23 AOTI_RUNTIME_SCALAR_TO_TENSOR(float64, double)
24 AOTI_RUNTIME_SCALAR_TO_TENSOR(uint8, uint8_t)
25 AOTI_RUNTIME_SCALAR_TO_TENSOR(uint16, uint16_t)
26 AOTI_RUNTIME_SCALAR_TO_TENSOR(uint32, uint32_t)
27 AOTI_RUNTIME_SCALAR_TO_TENSOR(uint64, uint64_t)
28 AOTI_RUNTIME_SCALAR_TO_TENSOR(int8, int8_t)
29 AOTI_RUNTIME_SCALAR_TO_TENSOR(int16, int16_t)
30 AOTI_RUNTIME_SCALAR_TO_TENSOR(int32, int32_t)
31 AOTI_RUNTIME_SCALAR_TO_TENSOR(int64, int64_t)
32 AOTI_RUNTIME_SCALAR_TO_TENSOR(bool, bool)
33 #undef AOTI_RUNTIME_SCALAR_TO_TENSOR
34 
35 } // namespace torch::aot_inductor
36