#pragma once

#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// WARNING: Be careful when adding new includes here. This header will be used
// in model.so, and should not refer to any aten/c10 headers except the stable
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#if defined(__GNUC__) || defined(__clang__)
#define AOTI_NOINLINE __attribute__((noinline))
#elif _MSC_VER
#define AOTI_NOINLINE __declspec(noinline)
#else
#define AOTI_NOINLINE
#endif

AOTI_NOINLINE static void throw_exception(
    const char* call,
    const char* file,
    int64_t line) {
  std::stringstream ss;
  ss << call << " API call failed at " << file << ", line " << line;
  throw std::runtime_error(ss.str());
}

#define AOTI_TORCH_ERROR_CODE_CHECK(call)         \
  if ((call) != AOTI_TORCH_SUCCESS) {             \
    throw_exception(#call, __FILE__, __LINE__);   \
  }

using AOTIRuntimeError = int32_t;
#define AOTI_RUNTIME_SUCCESS 0
#define AOTI_RUNTIME_FAILURE 1

#define AOTI_RUNTIME_ERROR_CODE_CHECK(call)       \
  if ((call) != AOTI_RUNTIME_SUCCESS) {           \
    throw_exception(#call, __FILE__, __LINE__);   \
  }
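
// Usage sketch: both macros wrap a call that returns a status code and turn a
// failure into a C++ exception that records the stringified call site.
// `some_aoti_torch_call` is a hypothetical shim function, used only for
// illustration.
//
//   AtenTensorHandle handle = nullptr;
//   AOTI_TORCH_ERROR_CODE_CHECK(some_aoti_torch_call(&handle));
//   // On failure, throws std::runtime_error with a message like
//   //   "some_aoti_torch_call(&handle) API call failed at <file>, line <N>"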

namespace torch::aot_inductor {

using DeleterFnPtr = void (*)(void*);

inline void noop_deleter(void*) {}

inline void delete_tensor_object(void* ptr) {
  AOTI_TORCH_ERROR_CODE_CHECK(
      aoti_torch_delete_tensor_object(reinterpret_cast<AtenTensorHandle>(ptr)));
}

// RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI
class RAIIAtenTensorHandle {
 public:
  RAIIAtenTensorHandle() : handle_(nullptr, noop_deleter) {}
  RAIIAtenTensorHandle(const RAIIAtenTensorHandle& other) = delete;
  RAIIAtenTensorHandle& operator=(const RAIIAtenTensorHandle& other) = delete;

  // Steal the ownership from another RAIIAtenTensorHandle using std::move
  RAIIAtenTensorHandle(RAIIAtenTensorHandle&& other) = default;
  RAIIAtenTensorHandle& operator=(RAIIAtenTensorHandle&& other) = default;

  // Steal the ownership from raw AtenTensorHandle
  RAIIAtenTensorHandle(AtenTensorHandle handle)
      : handle_(handle, delete_tensor_object) {}

  ~RAIIAtenTensorHandle() {
    handle_.reset();
  }

  // Return a raw AtenTensorHandle to be used by aoti_torch functions
  // Note: this function does NOT transfer the ownership of the handle
  operator AtenTensorHandle() const {
    return handle_.get();
  }

  AtenTensorHandle release() {
    return handle_.release();
  }

  AtenTensorHandle get() const {
    return handle_.get();
  }

  void reset() {
    handle_.reset();
  }

  int64_t size(int64_t d) {
    int64_t size = 0;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_.get(), d, &size));
    return size;
  }

  int64_t stride(int64_t d) {
    int64_t stride = 0;
    AOTI_TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_stride(handle_.get(), d, &stride));
    return stride;
  }

  int64_t storage_offset() {
    int64_t storage_offset = 0;
    AOTI_TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_storage_offset(handle_.get(), &storage_offset));
    return storage_offset;
  }

 private:
  std::unique_ptr<AtenTensorOpaque, DeleterFnPtr> handle_;
};
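
// Usage sketch: wrap a raw handle produced by a shim call so it is released
// automatically. `make_tensor_via_shim` is a hypothetical function standing in
// for any shim API that returns an owning AtenTensorHandle via out-parameter.
//
//   AtenTensorHandle raw = nullptr;
//   AOTI_TORCH_ERROR_CODE_CHECK(make_tensor_via_shim(&raw));
//   RAIIAtenTensorHandle tensor(raw);   // takes ownership of `raw`
//   int64_t n = tensor.size(0);         // queries go through the C shim
//   // `tensor` is freed via aoti_torch_delete_tensor_object at scope exit;
//   // call tensor.release() instead to hand ownership back to the caller.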

// Steal the ownership from raw AtenTensorHandle to RAIIAtenTensorHandle
inline std::vector<RAIIAtenTensorHandle> steal_from_raw_handles_to_raii_handles(
    AtenTensorHandle* handles,
    size_t size) {
  std::vector<RAIIAtenTensorHandle> result;
  result.reserve(size);
  for (size_t i = 0; i < size; i++) {
    result.emplace_back(handles[i]);
    handles[i] = nullptr;
  }
  return result;
}
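
// Sketch of the transfer: after the call, every entry of the raw array is
// nulled out, so only the returned RAIIAtenTensorHandles own the tensors.
// `h0` and `h1` are hypothetical owning handles used for illustration.
//
//   AtenTensorHandle raw[2] = {h0, h1};
//   auto owned = steal_from_raw_handles_to_raii_handles(raw, 2);
//   // raw[0] == nullptr && raw[1] == nullptr; `owned` frees both on scope exit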

class ConstantHandle {
 public:
  ConstantHandle() = default;

  explicit ConstantHandle(AtenTensorHandle handle) : handle_(handle) {
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &data_));
  }

  operator AtenTensorHandle() const {
    return handle_;
  }

  AtenTensorHandle tensor() const {
    return handle_;
  }

  AtenTensorHandle get() const {
    return handle_;
  }

  void* data_ptr() const {
    return data_;
  }

 private:
  AtenTensorHandle handle_{};
  void* data_ = nullptr;
};
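
// ConstantHandle is a non-owning view of a constant tensor: it never deletes
// the underlying handle and caches the data pointer once at construction.
// Usage sketch, with `constant_handle` standing in for a handle owned by the
// surrounding model container:
//
//   ConstantHandle weight(constant_handle);
//   void* data = weight.data_ptr();       // cached pointer, no extra shim call
//   AtenTensorHandle raw = weight.get();  // still owned by the container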

inline void* get_data_ptr_wrapper(const ConstantHandle& constant) {
  return constant.data_ptr();
}

inline const ConstantHandle& unwrap_raii_handle_if_needed(
    const ConstantHandle& handle) {
  return handle;
}

// Shouldn't be called.
inline AtenTensorHandle wrap_with_raii_handle_if_needed(
    const ConstantHandle& handle) = delete;

#define CACHE_TORCH_DTYPE(typename) \
  static auto cached_torch_dtype_##typename = aoti_torch_dtype_##typename()

#define CACHE_TORCH_DEVICE(device)                \
  static auto cached_torch_device_type_##device = \
      aoti_torch_device_type_##device()

#define CACHE_TORCH_LAYOUT(layout) \
  static auto cached_torch_layout_##layout = aoti_torch_layout_##layout()
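
// Expansion sketch: CACHE_TORCH_DTYPE(float32), with float32 as an example
// dtype suffix, caches the shim's dtype query in a static variable, i.e.
//
//   static auto cached_torch_dtype_float32 = aoti_torch_dtype_float32();
//
// so the generated code queries the shim once and reuses the cached value.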

} // namespace torch::aot_inductor