1 #pragma once
2
3 #include <ATen/Generator.h>
4 #include <ATen/Tensor.h>
5 #include <ATen/core/List.h>
6 #include <c10/core/DeviceType.h>
7 #include <c10/core/SymIntArrayRef.h>
8 #include <c10/util/ArrayRef.h>
9 #include <c10/util/Logging.h>
10 #include <c10/util/OptionalArrayRef.h>
11 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
12 #include <optional>
13
// Runs the statements passed as macro arguments inside a try block and turns
// the outcome into a shim status code returned from the ENCLOSING function:
//   - a std::exception is logged at ERROR level (with e.what()) and
//     AOTI_TORCH_FAILURE is returned;
//   - any other thrown object is logged as "UNKNOWN" and AOTI_TORCH_FAILURE is
//     returned;
//   - if nothing throws, AOTI_TORCH_SUCCESS is returned.
// The macro is variadic so that commas inside the wrapped code do not split it
// into several macro arguments. Because the expansion ends with an
// unconditional `return`, any code placed after the macro is unreachable; it
// is meant to be the entire body of a function whose return type accepts these
// codes (NOTE(review): presumably AOTITorchError from shim.h — confirm).
#define AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(...)    \
  try {                                                    \
    __VA_ARGS__                                            \
  } catch (const std::exception& e) {                      \
    LOG(ERROR) << "Exception in aoti_torch: " << e.what(); \
    return AOTI_TORCH_FAILURE;                             \
  } catch (...) {                                          \
    LOG(ERROR) << "Exception in aoti_torch: UNKNOWN";      \
    return AOTI_TORCH_FAILURE;                             \
  }                                                        \
  return AOTI_TORCH_SUCCESS;
25
26 namespace torch::aot_inductor {
27
tensor_handle_to_tensor_pointer(AtenTensorHandle handle)28 inline at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) {
29 return reinterpret_cast<at::Tensor*>(handle);
30 }
31
tensor_pointer_to_tensor_handle(at::Tensor * tensor)32 inline AtenTensorHandle tensor_pointer_to_tensor_handle(at::Tensor* tensor) {
33 return reinterpret_cast<AtenTensorHandle>(tensor);
34 }
35
generator_handle_to_generator_pointer(AtenGeneratorHandle handle)36 inline at::Generator* generator_handle_to_generator_pointer(
37 AtenGeneratorHandle handle) {
38 return reinterpret_cast<at::Generator*>(handle);
39 }
40
generator_pointer_to_generator_handle(at::Generator * generator)41 inline AtenGeneratorHandle generator_pointer_to_generator_handle(
42 at::Generator* generator) {
43 return reinterpret_cast<AtenGeneratorHandle>(generator);
44 }
45
new_tensor_handle(at::Tensor && tensor)46 inline AtenTensorHandle new_tensor_handle(at::Tensor&& tensor) {
47 at::Tensor* new_tensor = new at::Tensor(std::move(tensor));
48 return tensor_pointer_to_tensor_handle(new_tensor);
49 }
50
assert_inf_and_nan(const std::string & tensor_name,at::Tensor & check_tensor)51 inline void assert_inf_and_nan(
52 const std::string& tensor_name,
53 at::Tensor& check_tensor) {
54 auto isnan_tensor = check_tensor.isnan();
55 if (isnan_tensor.any().item<bool>()) {
56 throw std::runtime_error("At least one NaN in " + tensor_name);
57 }
58 auto isinf_tensor = check_tensor.isinf();
59 if (isinf_tensor.any().item<bool>()) {
60 throw std::runtime_error("At least one INF in " + tensor_name);
61 }
62 }
63
64 // utility functions to convert a pointer to an optional value
// Same-type case: copies the value behind `ptr` into a std::optional. A null
// pointer maps to std::nullopt.
template <class T>
inline std::optional<T> pointer_to_optional(T* ptr) {
  if (ptr == nullptr) {
    return std::nullopt;
  }
  return std::make_optional(*ptr);
}
69
// Converting case: like pointer_to_optional(T*), but the pointee of type U is
// converted with an explicit T(...) cast. Disabled when T and U are the same
// type, so that the single-type overload handles that case.
template <class T, class U, typename = std::enable_if_t<!std::is_same_v<T, U>>>
inline std::optional<T> pointer_to_optional(U* ptr) {
  if (!ptr) {
    return std::nullopt;
  }
  return std::make_optional<T>(T(*ptr));
}
74
75 template <>
pointer_to_optional(AtenTensorHandle * ptr)76 inline std::optional<at::Tensor> pointer_to_optional(AtenTensorHandle* ptr) {
77 return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr))
78 : std::nullopt;
79 }
80
81 template <>
pointer_to_optional(const AtenTensorHandle * ptr)82 inline std::optional<at::Tensor> pointer_to_optional(
83 const AtenTensorHandle* ptr) {
84 return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr))
85 : std::nullopt;
86 }
87
88 template <>
pointer_to_optional(AtenGeneratorHandle * ptr)89 inline std::optional<at::Generator> pointer_to_optional(
90 AtenGeneratorHandle* ptr) {
91 return ptr ? std::make_optional(*generator_handle_to_generator_pointer(*ptr))
92 : std::nullopt;
93 }
94
pointer_to_optional_device(int32_t * device_type,int32_t device_index)95 inline std::optional<c10::Device> pointer_to_optional_device(
96 int32_t* device_type,
97 int32_t device_index) {
98 return device_type ? std::make_optional(c10::Device(
99 static_cast<c10::DeviceType>(*device_type),
100 static_cast<c10::DeviceIndex>(device_index)))
101 : std::nullopt;
102 }
103
104 // utility functions to convert a pointer to a list
// Type trait: is_optional<X>::value is true exactly when X is a specialization
// of std::optional. It is used to select between the pointer_to_list overloads.
template <typename T>
struct is_optional : std::bool_constant<false> {};

template <typename T>
struct is_optional<std::optional<T>> : std::bool_constant<true> {};
109
110 template <class T>
111 inline c10::ArrayRef<T> pointer_to_list(T* ptr, int64_t len) {
112 return c10::ArrayRef<T>(ptr, len);
113 }
114
115 template <
116 class T,
117 class U,
118 typename = std::enable_if_t<!std::is_same_v<T, U>>,
119 typename = std::enable_if_t<!is_optional<T>::value>>
120 inline std::vector<T> pointer_to_list(U* ptr, int64_t len) {
121 // std::vector<T> will be implicitly converted to c10::ArrayRef<T> at the call
122 // site
123 std::vector<T> result;
124 result.reserve(len);
125 for (int64_t i = 0; i < len; i++) {
126 result.emplace_back(T(ptr[i]));
127 }
128 return result;
129 }
130
131 template <class T, class U, typename = std::enable_if_t<is_optional<T>::value>>
132 inline std::vector<T> pointer_to_list(U** ptr, int64_t len) {
133 // Here U** denotes a list of optional arguments
134 // std::vector<T> will be implicitly converted to c10::ArrayRef<T> at the call
135 // site
136 std::vector<T> result;
137 result.reserve(len);
138 for (int64_t i = 0; i < len; i++) {
139 result.emplace_back(pointer_to_optional(ptr[i]));
140 }
141 return result;
142 }
143
144 template <>
145 inline std::vector<at::Tensor> pointer_to_list(
146 const AtenTensorHandle* ptr,
147 int64_t len) {
148 std::vector<at::Tensor> result;
149 result.reserve(len);
150 for (int64_t i = 0; i < len; i++) {
151 result.emplace_back(*tensor_handle_to_tensor_pointer(*ptr));
152 }
153 return result;
154 }
155
156 template <>
157 inline std::vector<std::optional<at::Tensor>> pointer_to_list(
158 const AtenTensorHandle** ptr,
159 int64_t len) {
160 std::vector<std::optional<at::Tensor>> result;
161 result.reserve(len);
162 for (int64_t i = 0; i < len; i++) {
163 result.emplace_back(pointer_to_optional<at::Tensor>(ptr[i]));
164 }
165 return result;
166 }
167
// Converts exactly N int32 flags at `ptr` into std::array<bool, N>. Any nonzero
// value becomes true. `ptr` must point to at least N readable elements.
template <int N>
inline std::array<bool, N> pointer_to_list(const int32_t* ptr) {
  std::array<bool, N> flags;
  for (int i = 0; i < N; ++i) {
    flags[i] = ptr[i] != 0;
  }
  return flags;
}
174
175 // Utility function to convert a pointer to an optional list of values
176 template <class T, class U>
177 inline std::optional<c10::ArrayRef<T>> pointer_to_optional_list(
178 U** ptr,
179 int64_t len) {
180 return ptr
181 ? std::make_optional<c10::ArrayRef<T>>(pointer_to_list<T>(*ptr, len))
182 : std::nullopt;
183 }
184
185 } // namespace torch::aot_inductor
186