#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <ATen/DynamicLibrary.h>

#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>

// TODO: Investigate why this is necessary; it fixes build problems in FRL
#if __has_include("filesystem")
#include <filesystem>
namespace fs = std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#endif

#ifndef _WIN32
#include <sys/stat.h>
#endif

namespace {
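// Minimal existence check for side-car files: std::filesystem on Windows,
// lstat() elsewhere (note that lstat does not follow symlinks).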
bool file_exists(const std::string& path) {
#ifdef _WIN32
  return fs::exists(path);
#else
  struct stat rc;
  return lstat(path.c_str(), &rc) == 0;
#endif
}
} // namespace

namespace torch::inductor {

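// Loads the AOTInductor-compiled model shared library and eagerly resolves
// the C ABI entry points (AOTInductorModelContainer*), so a missing export
// fails here rather than on the first call through the runner.
//
// Minimal usage sketch (assuming the CPU subclass declared in
// model_container_runner_cpu.h; exact signatures may differ by revision):
//
//   AOTIModelContainerRunnerCpu runner("/path/to/model.so");
//   std::vector<at::Tensor> inputs = {torch::randn({2, 3})};
//   std::vector<at::Tensor> outputs = runner.run(inputs);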
AOTIModelContainerRunner::AOTIModelContainerRunner(
    const std::string& model_so_path,
    size_t num_models,
    const std::string& device_str,
    const std::string& cubin_dir) {
  model_so_ = std::make_unique<at::DynamicLibrary>(model_so_path.c_str());
  TORCH_CHECK(model_so_, "Failed to load model: ", model_so_path);
  create_func_ = reinterpret_cast<decltype(create_func_)>(
      model_so_->sym("AOTInductorModelContainerCreateWithDevice"));
  delete_func_ = reinterpret_cast<decltype(delete_func_)>(
      model_so_->sym("AOTInductorModelContainerDelete"));
  get_num_outputs_func_ = reinterpret_cast<decltype(get_num_outputs_func_)>(
      model_so_->sym("AOTInductorModelContainerGetNumOutputs"));
  run_func_ = reinterpret_cast<decltype(run_func_)>(
      model_so_->sym("AOTInductorModelContainerRun"));
  get_num_constants_func_ = reinterpret_cast<decltype(get_num_constants_func_)>(
      model_so_->sym("AOTInductorModelContainerGetNumConstants"));
  get_constant_name_func_ = reinterpret_cast<decltype(get_constant_name_func_)>(
      model_so_->sym("AOTInductorModelContainerGetConstantName"));
  get_constant_original_fqn_func_ =
      reinterpret_cast<decltype(get_constant_original_fqn_func_)>(
          model_so_->sym("AOTInductorModelContainerGetConstantOriginalFQN"));
  get_constant_dtype_func_ =
      reinterpret_cast<decltype(get_constant_dtype_func_)>(
          model_so_->sym("AOTInductorModelContainerGetConstantDtype"));
  update_constant_buffer_func_ =
      reinterpret_cast<decltype(update_constant_buffer_func_)>(
          model_so_->sym("AOTInductorModelContainerUpdateConstantBuffer"));
  update_inactive_constant_buffer_func_ =
      reinterpret_cast<decltype(update_inactive_constant_buffer_func_)>(
          model_so_->sym(
              "AOTInductorModelContainerUpdateInactiveConstantBuffer"));
  run_const_fold_func_ = reinterpret_cast<decltype(run_const_fold_func_)>(
      model_so_->sym("AOTInductorModelContainerRunConstantFolding"));
  swap_constant_buffer_func_ =
      reinterpret_cast<decltype(swap_constant_buffer_func_)>(
          model_so_->sym("AOTInductorModelContainerSwapConstantBuffer"));
  get_call_spec_func_ = reinterpret_cast<decltype(get_call_spec_func_)>(
      model_so_->sym("AOTInductorModelContainerGetCallSpec"));

  // Hack: derive the companion .json file name from the model .so path by
  // swapping the extension.
  size_t lastindex = model_so_path.find_last_of('.');
  std::string json_filename = model_so_path.substr(0, lastindex) + ".json";

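  // If a companion json file is present, it describes extern kernels (e.g.
  // custom ops) that are not compiled into the .so; the OSS proxy executor
  // loads it and dispatches those calls at runtime.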
  if (file_exists(json_filename)) {
    proxy_executor_ = std::make_unique<torch::aot_inductor::OSSProxyExecutor>(
        json_filename, device_str == "cpu");
    proxy_executor_handle_ =
        reinterpret_cast<AOTIProxyExecutorHandle>(proxy_executor_.get());
  }

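  // Create the model container itself. cubin_dir, when non-empty, points at
  // a directory of precompiled GPU kernel binaries for the model to load.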
  AOTI_RUNTIME_ERROR_CODE_CHECK(create_func_(
      &container_handle_,
      num_models,
      device_str.c_str(),
      cubin_dir.empty() ? nullptr : cubin_dir.c_str()));
}

AOTIModelContainerRunner::~AOTIModelContainerRunner() {
  AOTIRuntimeError result = delete_func_(container_handle_);
  TORCH_CHECK(
      result == AOTI_RUNTIME_SUCCESS, "AOTInductorModelContainerDelete failed");
}

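// Runs one inference. Inputs are wrapped in newly allocated tensor handles
// whose ownership is transferred to the compiled model; outputs come back as
// owning handles that are stolen into at::Tensors before returning.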
std::vector<at::Tensor> AOTIModelContainerRunner::run(
    std::vector<at::Tensor>& inputs,
    AOTInductorStreamHandle cuda_stream_handle) {
  auto input_handles =
      torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(inputs);

  // For outputs, we only allocate a vector to hold the returned tensor
  // handles; the actual output tensor storage is allocated by the model.
  size_t num_outputs = 0;
  AOTI_RUNTIME_ERROR_CODE_CHECK(
      get_num_outputs_func_(container_handle_, &num_outputs));
  std::vector<AtenTensorHandle> output_handles(num_outputs);

  AOTI_RUNTIME_ERROR_CODE_CHECK(run_func_(
      container_handle_,
      input_handles.data(),
      input_handles.size(),
      output_handles.data(),
      output_handles.size(),
      cuda_stream_handle,
      proxy_executor_handle_));

  return torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
      output_handles.data(), output_handles.size());
}

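// Maps each constant's name in the compiled model to the fully-qualified
// name (FQN) it had in the original module.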
std::unordered_map<std::string, std::string> AOTIModelContainerRunner::
    getConstantNamesToOriginalFQNs() const {
  std::unordered_map<std::string, std::string> result;
  size_t num_constants{0};
  AOTI_RUNTIME_ERROR_CODE_CHECK(
      get_num_constants_func_(container_handle_, &num_constants));
  for (size_t i = 0; i < num_constants; ++i) {
    const char* name{nullptr};
    const char* original_fqn{nullptr};
    AOTI_RUNTIME_ERROR_CODE_CHECK(
        get_constant_name_func_(container_handle_, i, &name));
    AOTI_RUNTIME_ERROR_CODE_CHECK(
        get_constant_original_fqn_func_(container_handle_, i, &original_fqn));
    result.emplace(name, original_fqn);
  }
  return result;
}

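// Maps each constant's name to its dtype, returned as the integer dtype code
// used by the AOTI C ABI.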
std::unordered_map<std::string, int32_t> AOTIModelContainerRunner::
    getConstantNamesToDtypes() const {
  std::unordered_map<std::string, int32_t> result;
  size_t num_constants{0};
  AOTI_RUNTIME_ERROR_CODE_CHECK(
      get_num_constants_func_(container_handle_, &num_constants));
  for (size_t i = 0; i < num_constants; ++i) {
    const char* name{nullptr};
    int32_t dtype{0};
    AOTI_RUNTIME_ERROR_CODE_CHECK(
        get_constant_name_func_(container_handle_, i, &name));
    AOTI_RUNTIME_ERROR_CODE_CHECK(
        get_constant_dtype_func_(container_handle_, i, &dtype));
    result.emplace(name, dtype);
  }
  return result;
}

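// Writes new values into the container's constant buffer. The container keeps
// two buffers; with use_inactive set, the update targets the buffer not
// currently serving inference, and check_full_update requires const_map to
// cover every constant.
//
// A double-buffered weight update typically looks like this (sketch):
//
//   runner.update_inactive_constant_buffer(new_weights);
//   runner.run_const_fold(/*use_inactive=*/true, stream);
//   runner.swap_constant_buffer();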
void AOTIModelContainerRunner::update_constant_buffer(
    const TensorConstantMap& const_map,
    bool use_inactive,
    bool check_full_update) {
  AOTI_RUNTIME_ERROR_CODE_CHECK(update_constant_buffer_func_(
      container_handle_,
      (AOTInductorConstantMapHandle)&const_map,
      use_inactive,
      check_full_update));
}

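// Convenience wrapper that always updates the inactive constant buffer.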
void AOTIModelContainerRunner::update_inactive_constant_buffer(
    const TensorConstantMap& const_map) {
  AOTI_RUNTIME_ERROR_CODE_CHECK(update_inactive_constant_buffer_func_(
      container_handle_, (AOTInductorConstantMapHandle)&const_map));
}

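// Re-runs constant folding inside the container, e.g. after updating the
// constant buffer, so folded constants derived from the new weights are
// recomputed.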
void AOTIModelContainerRunner::run_const_fold(
    bool use_inactive,
    AOTInductorStreamHandle cuda_stream_handle) {
  AOTI_RUNTIME_ERROR_CODE_CHECK(run_const_fold_func_(
      container_handle_,
      use_inactive,
      cuda_stream_handle,
      proxy_executor_handle_));
}

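// Promotes the inactive constant buffer to active, completing a
// double-buffered weight update.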
void AOTIModelContainerRunner::swap_constant_buffer() {
  AOTI_RUNTIME_ERROR_CODE_CHECK(swap_constant_buffer_func_(container_handle_));
}

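// Returns the serialized pytree call spec as {in_spec, out_spec}, describing
// how the flat tensor lists map onto the original model's inputs and outputs.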
std::vector<std::string> AOTIModelContainerRunner::get_call_spec() {
  const char* in_spec = nullptr;
  const char* out_spec = nullptr;
  AOTI_RUNTIME_ERROR_CODE_CHECK(
      get_call_spec_func_(container_handle_, &in_spec, &out_spec));
  return {in_spec, out_spec};
}

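// Global registry of factories for device-specific runner subclasses. Using a
// function-local static avoids static-initialization-order issues across
// translation units.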
std::unordered_map<std::string, CreateAOTIModelRunnerFunc>&
getAOTIModelRunnerRegistry() {
  static std::unordered_map<std::string, CreateAOTIModelRunnerFunc>
      aoti_model_runner_registry_;
  return aoti_model_runner_registry_;
}

} // namespace torch::inductor
#endif