xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Generator.h>
4 #include <ATen/Tensor.h>
5 #include <ATen/core/List.h>
6 #include <c10/core/DeviceType.h>
7 #include <c10/core/SymIntArrayRef.h>
8 #include <c10/util/ArrayRef.h>
9 #include <c10/util/Logging.h>
10 #include <c10/util/OptionalArrayRef.h>
11 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
12 #include <optional>
13 
// Runs the statements given as macro arguments inside a try block and turns
// the outcome into a return statement: AOTI_TORCH_SUCCESS when nothing was
// thrown, AOTI_TORCH_FAILURE (after logging the error) when any exception
// escaped. The expansion always ends in `return`, so the macro is meant to be
// the last thing in a function that returns one of these codes.
#define AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(...)      \
  try {                                                      \
    __VA_ARGS__                                              \
  } catch (const std::exception& err) {                      \
    LOG(ERROR) << "Exception in aoti_torch: " << err.what(); \
    return AOTI_TORCH_FAILURE;                               \
  } catch (...) {                                            \
    LOG(ERROR) << "Exception in aoti_torch: UNKNOWN";        \
    return AOTI_TORCH_FAILURE;                               \
  }                                                          \
  return AOTI_TORCH_SUCCESS;
25 
26 namespace torch::aot_inductor {
27 
tensor_handle_to_tensor_pointer(AtenTensorHandle handle)28 inline at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) {
29   return reinterpret_cast<at::Tensor*>(handle);
30 }
31 
tensor_pointer_to_tensor_handle(at::Tensor * tensor)32 inline AtenTensorHandle tensor_pointer_to_tensor_handle(at::Tensor* tensor) {
33   return reinterpret_cast<AtenTensorHandle>(tensor);
34 }
35 
generator_handle_to_generator_pointer(AtenGeneratorHandle handle)36 inline at::Generator* generator_handle_to_generator_pointer(
37     AtenGeneratorHandle handle) {
38   return reinterpret_cast<at::Generator*>(handle);
39 }
40 
generator_pointer_to_generator_handle(at::Generator * generator)41 inline AtenGeneratorHandle generator_pointer_to_generator_handle(
42     at::Generator* generator) {
43   return reinterpret_cast<AtenGeneratorHandle>(generator);
44 }
45 
new_tensor_handle(at::Tensor && tensor)46 inline AtenTensorHandle new_tensor_handle(at::Tensor&& tensor) {
47   at::Tensor* new_tensor = new at::Tensor(std::move(tensor));
48   return tensor_pointer_to_tensor_handle(new_tensor);
49 }
50 
assert_inf_and_nan(const std::string & tensor_name,at::Tensor & check_tensor)51 inline void assert_inf_and_nan(
52     const std::string& tensor_name,
53     at::Tensor& check_tensor) {
54   auto isnan_tensor = check_tensor.isnan();
55   if (isnan_tensor.any().item<bool>()) {
56     throw std::runtime_error("At least one NaN in " + tensor_name);
57   }
58   auto isinf_tensor = check_tensor.isinf();
59   if (isinf_tensor.any().item<bool>()) {
60     throw std::runtime_error("At least one INF in " + tensor_name);
61   }
62 }
63 
64 // utility functions to convert a pointer to an optional value
// Copies the pointee into a std::optional<T>; a null pointer maps to an empty
// optional.
template <class T>
inline std::optional<T> pointer_to_optional(T* ptr) {
  if (!ptr) {
    return std::nullopt;
  }
  return std::make_optional(*ptr);
}
69 
// Converting overload: when the requested element type T differs from the
// pointee type U, the pointee is explicitly converted to T before being
// wrapped. A null pointer maps to an empty optional.
template <class T, class U, typename = std::enable_if_t<!std::is_same_v<T, U>>>
inline std::optional<T> pointer_to_optional(U* ptr) {
  if (!ptr) {
    return std::nullopt;
  }
  return std::make_optional<T>(T(*ptr));
}
74 
75 template <>
pointer_to_optional(AtenTensorHandle * ptr)76 inline std::optional<at::Tensor> pointer_to_optional(AtenTensorHandle* ptr) {
77   return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr))
78              : std::nullopt;
79 }
80 
81 template <>
pointer_to_optional(const AtenTensorHandle * ptr)82 inline std::optional<at::Tensor> pointer_to_optional(
83     const AtenTensorHandle* ptr) {
84   return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr))
85              : std::nullopt;
86 }
87 
88 template <>
pointer_to_optional(AtenGeneratorHandle * ptr)89 inline std::optional<at::Generator> pointer_to_optional(
90     AtenGeneratorHandle* ptr) {
91   return ptr ? std::make_optional(*generator_handle_to_generator_pointer(*ptr))
92              : std::nullopt;
93 }
94 
pointer_to_optional_device(int32_t * device_type,int32_t device_index)95 inline std::optional<c10::Device> pointer_to_optional_device(
96     int32_t* device_type,
97     int32_t device_index) {
98   return device_type ? std::make_optional(c10::Device(
99                            static_cast<c10::DeviceType>(*device_type),
100                            static_cast<c10::DeviceIndex>(device_index)))
101                      : std::nullopt;
102 }
103 
104 // utility functions to convert a pointer to a list
// Type trait whose `value` is true exactly when T is an instantiation of
// std::optional. The pointer_to_list overloads in this file use it to choose
// between the plain and the optional-element conversions.
template <typename T>
struct is_optional : std::bool_constant<false> {};

template <typename Inner>
struct is_optional<std::optional<Inner>> : std::bool_constant<true> {};
109 
110 template <class T>
111 inline c10::ArrayRef<T> pointer_to_list(T* ptr, int64_t len) {
112   return c10::ArrayRef<T>(ptr, len);
113 }
114 
115 template <
116     class T,
117     class U,
118     typename = std::enable_if_t<!std::is_same_v<T, U>>,
119     typename = std::enable_if_t<!is_optional<T>::value>>
120 inline std::vector<T> pointer_to_list(U* ptr, int64_t len) {
121   // std::vector<T> will be implicitly converted to c10::ArrayRef<T> at the call
122   // site
123   std::vector<T> result;
124   result.reserve(len);
125   for (int64_t i = 0; i < len; i++) {
126     result.emplace_back(T(ptr[i]));
127   }
128   return result;
129 }
130 
131 template <class T, class U, typename = std::enable_if_t<is_optional<T>::value>>
132 inline std::vector<T> pointer_to_list(U** ptr, int64_t len) {
133   // Here U** denotes a list of optional arguments
134   // std::vector<T> will be implicitly converted to c10::ArrayRef<T> at the call
135   // site
136   std::vector<T> result;
137   result.reserve(len);
138   for (int64_t i = 0; i < len; i++) {
139     result.emplace_back(pointer_to_optional(ptr[i]));
140   }
141   return result;
142 }
143 
144 template <>
145 inline std::vector<at::Tensor> pointer_to_list(
146     const AtenTensorHandle* ptr,
147     int64_t len) {
148   std::vector<at::Tensor> result;
149   result.reserve(len);
150   for (int64_t i = 0; i < len; i++) {
151     result.emplace_back(*tensor_handle_to_tensor_pointer(*ptr));
152   }
153   return result;
154 }
155 
156 template <>
157 inline std::vector<std::optional<at::Tensor>> pointer_to_list(
158     const AtenTensorHandle** ptr,
159     int64_t len) {
160   std::vector<std::optional<at::Tensor>> result;
161   result.reserve(len);
162   for (int64_t i = 0; i < len; i++) {
163     result.emplace_back(pointer_to_optional<at::Tensor>(ptr[i]));
164   }
165   return result;
166 }
167 
// Reads N int32 flags starting at `ptr` and returns them as a fixed-size
// array of bool; any non-zero input becomes true.
template <int N>
inline std::array<bool, N> pointer_to_list(const int32_t* ptr) {
  std::array<bool, N> flags;
  std::transform(ptr, ptr + N, flags.begin(), [](int32_t raw) {
    return raw != 0;
  });
  return flags;
}
174 
175 // Utility function to convert a pointer to an optional list of values
176 template <class T, class U>
177 inline std::optional<c10::ArrayRef<T>> pointer_to_optional_list(
178     U** ptr,
179     int64_t len) {
180   return ptr
181       ? std::make_optional<c10::ArrayRef<T>>(pointer_to_list<T>(*ptr, len))
182       : std::nullopt;
183 }
184 
185 } // namespace torch::aot_inductor
186