1 #pragma once
2
3 #include <torch/csrc/inductor/aoti_runtime/utils.h>
4
5 namespace torch::aot_inductor {
6
7 template <typename T>
scalar_to_tensor_handle(T value)8 inline RAIIAtenTensorHandle scalar_to_tensor_handle(T value) {
9 throw std::runtime_error("Unsupported scalar_to_tensor_handle");
10 }
11
12 // Specialize for supported C++ primitive types
13 #define AOTI_RUNTIME_SCALAR_TO_TENSOR(dtype, ctype) \
14 template <> \
15 inline RAIIAtenTensorHandle scalar_to_tensor_handle<ctype>(ctype value) { \
16 AtenTensorHandle tensor_handle; \
17 AOTI_TORCH_ERROR_CODE_CHECK( \
18 aoti_torch_scalar_to_tensor_##dtype(value, &tensor_handle)); \
19 return RAIIAtenTensorHandle(tensor_handle); \
20 }
21
22 AOTI_RUNTIME_SCALAR_TO_TENSOR(float32, float)
23 AOTI_RUNTIME_SCALAR_TO_TENSOR(float64, double)
24 AOTI_RUNTIME_SCALAR_TO_TENSOR(uint8, uint8_t)
25 AOTI_RUNTIME_SCALAR_TO_TENSOR(uint16, uint16_t)
26 AOTI_RUNTIME_SCALAR_TO_TENSOR(uint32, uint32_t)
27 AOTI_RUNTIME_SCALAR_TO_TENSOR(uint64, uint64_t)
28 AOTI_RUNTIME_SCALAR_TO_TENSOR(int8, int8_t)
29 AOTI_RUNTIME_SCALAR_TO_TENSOR(int16, int16_t)
30 AOTI_RUNTIME_SCALAR_TO_TENSOR(int32, int32_t)
31 AOTI_RUNTIME_SCALAR_TO_TENSOR(int64, int64_t)
32 AOTI_RUNTIME_SCALAR_TO_TENSOR(bool, bool)
33 #undef AOTI_RUNTIME_SCALAR_TO_TENSOR
34
35 } // namespace torch::aot_inductor
36