#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <ATen/DynamicLibrary.h>

#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>

// TODO: Investigate why this is necessary; it fixes build problems in FRL
#if __has_include("filesystem")
#include <filesystem>
namespace fs = std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#endif

#ifndef _WIN32
#include <sys/stat.h>
#endif

namespace {
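// On Windows we go through std::filesystem; on POSIX we call lstat() directly,
// which reports on symlinks themselves rather than following them.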
bool file_exists(const std::string& path) {
#ifdef _WIN32
  return fs::exists(path);
#else
  struct stat rc;
  return lstat(path.c_str(), &rc) == 0;
#endif
}
} // namespace

namespace torch::inductor {

AOTIModelContainerRunner::AOTIModelContainerRunner(
    const std::string& model_so_path,
    size_t num_models,
    const std::string& device_str,
    const std::string& cubin_dir) {
  model_so_ = std::make_unique<at::DynamicLibrary>(model_so_path.c_str());
  TORCH_CHECK(model_so_, "Failed to load model: ", model_so_path);
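  // Resolve the C entry points exported by the AOTInductor-compiled shared
  // library. at::DynamicLibrary::sym() throws if a symbol is missing, so a
  // mismatched .so fails loudly here rather than at call time.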
  create_func_ = reinterpret_cast<decltype(create_func_)>(
      model_so_->sym("AOTInductorModelContainerCreateWithDevice"));
  delete_func_ = reinterpret_cast<decltype(delete_func_)>(
      model_so_->sym("AOTInductorModelContainerDelete"));
  get_num_outputs_func_ = reinterpret_cast<decltype(get_num_outputs_func_)>(
      model_so_->sym("AOTInductorModelContainerGetNumOutputs"));
  run_func_ = reinterpret_cast<decltype(run_func_)>(
      model_so_->sym("AOTInductorModelContainerRun"));
  get_num_constants_func_ = reinterpret_cast<decltype(get_num_constants_func_)>(
      model_so_->sym("AOTInductorModelContainerGetNumConstants"));
  get_constant_name_func_ = reinterpret_cast<decltype(get_constant_name_func_)>(
      model_so_->sym("AOTInductorModelContainerGetConstantName"));
  get_constant_original_fqn_func_ =
      reinterpret_cast<decltype(get_constant_original_fqn_func_)>(
          model_so_->sym("AOTInductorModelContainerGetConstantOriginalFQN"));
  get_constant_dtype_func_ =
      reinterpret_cast<decltype(get_constant_dtype_func_)>(
          model_so_->sym("AOTInductorModelContainerGetConstantDtype"));
  update_constant_buffer_func_ =
      reinterpret_cast<decltype(update_constant_buffer_func_)>(
          model_so_->sym("AOTInductorModelContainerUpdateConstantBuffer"));
  update_inactive_constant_buffer_func_ =
      reinterpret_cast<decltype(update_inactive_constant_buffer_func_)>(
          model_so_->sym(
              "AOTInductorModelContainerUpdateInactiveConstantBuffer"));
  run_const_fold_func_ = reinterpret_cast<decltype(run_const_fold_func_)>(
      model_so_->sym("AOTInductorModelContainerRunConstantFolding"));
  swap_constant_buffer_func_ =
      reinterpret_cast<decltype(swap_constant_buffer_func_)>(
          model_so_->sym("AOTInductorModelContainerSwapConstantBuffer"));
  get_call_spec_func_ = reinterpret_cast<decltype(get_call_spec_func_)>(
      model_so_->sym("AOTInductorModelContainerGetCallSpec"));

  // Derive the companion JSON file name from the model .so path by replacing
  // the extension, e.g. /path/to/model.so -> /path/to/model.json.
  size_t lastindex = model_so_path.find_last_of('.');
  std::string json_filename = model_so_path.substr(0, lastindex) + ".json";

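  // If a serialized-ops JSON file sits next to the .so, wire up an
  // OSSProxyExecutor to execute the fallback/custom ops recorded in it.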
  if (file_exists(json_filename)) {
    proxy_executor_ = std::make_unique<torch::aot_inductor::OSSProxyExecutor>(
        json_filename, device_str == "cpu");
    proxy_executor_handle_ =
        reinterpret_cast<AOTIProxyExecutorHandle>(proxy_executor_.get());
  }

  AOTI_RUNTIME_ERROR_CODE_CHECK(create_func_(
      &container_handle_,
      num_models,
      device_str.c_str(),
      cubin_dir.empty() ? nullptr : cubin_dir.c_str()));
}
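// A minimal usage sketch (assuming the CPU subclass
// AOTIModelContainerRunnerCpu, which supplies a default stream handle):
//
//   AOTIModelContainerRunnerCpu runner("/path/to/model.so");
//   std::vector<at::Tensor> inputs{torch::randn({2, 3})};
//   std::vector<at::Tensor> outputs = runner.run(inputs);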

AOTIModelContainerRunner::~AOTIModelContainerRunner() {
  AOTIRuntimeError result = delete_func_(container_handle_);
  TORCH_CHECK(
      result == AOTI_RUNTIME_SUCCESS, "AOTInductorModelContainerDelete failed");
}

std::vector<at::Tensor> AOTIModelContainerRunner::run(
    std::vector<at::Tensor>& inputs,
    AOTInductorStreamHandle cuda_stream_handle) {
  auto input_handles =
      torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(inputs);
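  // The handles are the ABI-stable representation (AtenTensorHandle) that the
  // compiled model understands; each one wraps a tensor sharing storage with
  // the corresponding input.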

  // For outputs, we only allocate a vector to hold the returned tensor
  // handles; the actual output tensor storage is allocated inside the
  // compiled model.
  size_t num_outputs = 0;
  AOTI_RUNTIME_ERROR_CODE_CHECK(
      get_num_outputs_func_(container_handle_, &num_outputs));
  std::vector<AtenTensorHandle> output_handles(num_outputs);

  AOTI_RUNTIME_ERROR_CODE_CHECK(run_func_(
      container_handle_,
      input_handles.data(),
      input_handles.size(),
      output_handles.data(),
      output_handles.size(),
      cuda_stream_handle,
      proxy_executor_handle_));

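  // Converting back steals ownership from the handles, so output_handles does
  // not need separate cleanup after this call.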
  return torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
      output_handles.data(), output_handles.size());
}

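// Maps each constant's generated name to its fully qualified name (FQN) in
// the original eager model.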
std::unordered_map<std::string, std::string> AOTIModelContainerRunner::
    getConstantNamesToOriginalFQNs() const {
  std::unordered_map<std::string, std::string> result;
  size_t num_constants{0};
  AOTI_RUNTIME_ERROR_CODE_CHECK(
      get_num_constants_func_(container_handle_, &num_constants));
  for (size_t i = 0; i < num_constants; ++i) {
    const char* name{nullptr};
    const char* original_fqn{nullptr};
    AOTI_RUNTIME_ERROR_CODE_CHECK(
        get_constant_name_func_(container_handle_, i, &name));
    AOTI_RUNTIME_ERROR_CODE_CHECK(
        get_constant_original_fqn_func_(container_handle_, i, &original_fqn));
    result.emplace(name, original_fqn);
  }
  return result;
}

std::unordered_map<std::string, int32_t> AOTIModelContainerRunner::
    getConstantNamesToDtypes() const {
  std::unordered_map<std::string, int32_t> result;
  size_t num_constants{0};
  AOTI_RUNTIME_ERROR_CODE_CHECK(
      get_num_constants_func_(container_handle_, &num_constants));
  for (size_t i = 0; i < num_constants; ++i) {
    const char* name{nullptr};
    int32_t dtype{0};
    AOTI_RUNTIME_ERROR_CODE_CHECK(
        get_constant_name_func_(container_handle_, i, &name));
    AOTI_RUNTIME_ERROR_CODE_CHECK(
        get_constant_dtype_func_(container_handle_, i, &dtype));
    result.emplace(name, dtype);
  }
  return result;
}

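// The container keeps two constant buffers (active and inactive). Updating
// the inactive buffer and then swapping lets new weights be staged without
// disturbing in-flight inference runs.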
void AOTIModelContainerRunner::update_constant_buffer(
    const TensorConstantMap& const_map,
    bool use_inactive,
    bool check_full_update) {
  AOTI_RUNTIME_ERROR_CODE_CHECK(update_constant_buffer_func_(
      container_handle_,
      (AOTInductorConstantMapHandle)&const_map,
      use_inactive,
      check_full_update));
}

void AOTIModelContainerRunner::update_inactive_constant_buffer(
    const TensorConstantMap& const_map) {
  AOTI_RUNTIME_ERROR_CODE_CHECK(update_inactive_constant_buffer_func_(
      container_handle_, (AOTInductorConstantMapHandle)&const_map));
}

void AOTIModelContainerRunner::run_const_fold(
    bool use_inactive,
    AOTInductorStreamHandle cuda_stream_handle) {
  AOTI_RUNTIME_ERROR_CODE_CHECK(run_const_fold_func_(
      container_handle_,
      use_inactive,
      cuda_stream_handle,
      proxy_executor_handle_));
}

void AOTIModelContainerRunner::swap_constant_buffer() {
  AOTI_RUNTIME_ERROR_CODE_CHECK(swap_constant_buffer_func_(container_handle_));
}

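// Returns the serialized pytree specs for the model's inputs and outputs,
// in that order.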
std::vector<std::string> AOTIModelContainerRunner::get_call_spec() {
  const char* in_spec = nullptr;
  const char* out_spec = nullptr;
  AOTI_RUNTIME_ERROR_CODE_CHECK(
      get_call_spec_func_(container_handle_, &in_spec, &out_spec));
  return {in_spec, out_spec};
}

std::unordered_map<std::string, CreateAOTIModelRunnerFunc>&
getAOTIModelRunnerRegistry() {
  static std::unordered_map<std::string, CreateAOTIModelRunnerFunc>
      aoti_model_runner_registry_;
  return aoti_model_runner_registry_;
}
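// A sketch of how a device backend could register a runner factory with the
// registry (the device key "privateuse1" and MyBackendModelRunner are
// hypothetical; the factory must match CreateAOTIModelRunnerFunc):
//
//   getAOTIModelRunnerRegistry()["privateuse1"] =
//       [](const std::string& model_so_path,
//          size_t num_models,
//          const std::string& device_str,
//          const std::string& cubin_dir)
//           -> std::unique_ptr<AOTIModelContainerRunner> {
//         return std::make_unique<MyBackendModelRunner>(
//             model_so_path, num_models, device_str, cubin_dir);
//       };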

} // namespace torch::inductor
#endif