#pragma once

#include <dlfcn.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <optional>
#include <regex>
#include <stdexcept>
#include <unordered_map>
#include <utility>

// WARNING: Be careful when adding new includes here. This header will be used
// in model.so, and should not refer to any aten/c10 headers except the stable
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>

#define AOTI_RUNTIME_CHECK(EXPR, MSG) \
  do {                                \
    bool ok = EXPR;                   \
    if (!ok) {                        \
      throw std::runtime_error(MSG);  \
    }                                 \
  } while (0)

// At codegen time, we write out a binary file called constants.bin.
// We then turn the raw binary into an object file that exposes these
// symbols and link it into the final .so.
// For information on the binary format, see `man objcopy`, under
// the "binary-architecture" flag:
// https://man7.org/linux/man-pages/man1/objcopy.1.html
// TODO: use #embed in C++ 23 once available
// The constants are NOT readonly because they may be mutated.
extern uint8_t _binary_constants_bin_start[];
extern uint8_t _binary_constants_bin_end[];
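// Illustrative only: a conversion roughly like
//   objcopy -I binary -O elf64-x86-64 -B i386:x86-64 constants.bin constants.o
// is the kind of step that yields symbols named
// _binary_constants_bin_start/_end bracketing the raw bytes of constants.bin.
// The exact invocation is produced by the build/codegen and may differ per
// platform; the symbols declared above are the only contract this header
// relies on.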

#define AOTI_CONST_GPU_ALIGNMENT 64

namespace {

#ifdef USE_CUDA

using CUDAPtr = std::unique_ptr<void, std::function<void(void*)>>;

CUDAPtr RAII_cudaMalloc(size_t num_bytes) {
  void* data_ptr;
  AOTI_RUNTIME_DEVICE_CHECK(cudaMalloc((void**)&data_ptr, num_bytes));
  auto deleter = [](void* ptr) { AOTI_RUNTIME_DEVICE_CHECK(cudaFree(ptr)); };
  return CUDAPtr(data_ptr, deleter);
}

#endif // USE_CUDA

} // anonymous namespace

namespace torch::aot_inductor {
using ConstantMap = std::unordered_map<std::string, RAIIAtenTensorHandle>;

// Valid device strings are: cpu, cuda, cuda:0, cuda:1, ...
// Update the list here if more devices are supported in the future.
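// For example, parse_device_str("cuda:1", type, idx) sets type to
// aoti_torch_device_type_cuda() and idx to 1, while parse_device_str("cpu",
// type, idx) sets type to aoti_torch_device_type_cpu() and idx to -1 (no
// explicit index given).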
inline void parse_device_str(
    const std::string& device_str,
    int32_t& device_type,
    int32_t& device_idx) {
  std::regex re("(cpu|cuda)(:([0-9]+))?");
  std::smatch sm;
  bool matched = std::regex_match(device_str, sm, re);
  AOTI_RUNTIME_CHECK(matched, "Invalid device: " + device_str);

  if (sm[1].str() == "cpu") {
    device_type = aoti_torch_device_type_cpu();
  } else if (sm[1].str() == "cuda") {
    device_type = aoti_torch_device_type_cuda();
  } else {
    AOTI_RUNTIME_CHECK(false, "Invalid device: " + device_str);
  }

  if (sm[3].matched) {
    device_idx = stoi(sm[3].str());
  } else {
    device_idx = -1;
  }
}

// Defines the base class for AOTInductorModel, which is generated by the
// AOTInductor cpp codegen. Since we do not need dynamic dispatch, we rely
// on the curiously recurring template pattern (CRTP) to avoid runtime
// v-table overhead. The generated AOTInductorModel is specialized with
// methods such as run_impl.
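// Concretely, AOTInductorModelBase<Model>::run calls
// static_cast<Model*>(this)->run_impl(...), so the dispatch into the
// generated class is resolved at compile time rather than through a virtual
// call.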
template <typename Model>
class AOTInductorModelBase {
 public:
  AOTInductorModelBase(
      size_t num_inputs,
      size_t num_outputs,
      size_t num_constants,
      const std::string& device_str,
      std::optional<std::string> cubin_dir)
      : inputs_info_(num_inputs),
        outputs_info_(num_outputs),
        constants_info_(num_constants),
        cubin_dir_(std::move(cubin_dir)) {
    parse_device_str(device_str, device_type_, device_idx_);

#ifdef USE_CUDA
    if (device_idx_ == -1) {
      AOTI_RUNTIME_DEVICE_CHECK(cudaGetDevice(&device_idx_));
    }
#endif // USE_CUDA
  }

  ~AOTInductorModelBase() {
#ifdef USE_CUDA
    if (run_finished_) {
      auto code = cudaEventDestroy(*run_finished_);
      if (code != cudaSuccess) {
        std::cerr << "Failed to destroy CUDA event in AOTInductor model: "
                  << cudaGetErrorString(code) << std::endl;
      }
    }
#endif // USE_CUDA
  }

  AOTInductorModelBase(AOTInductorModelBase&&) = delete;
  AOTInductorModelBase& operator=(AOTInductorModelBase&&) = delete;
  AOTInductorModelBase(const AOTInductorModelBase&) = delete;
  AOTInductorModelBase& operator=(const AOTInductorModelBase&) = delete;

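  // Runs a single inference. With USE_CUDA, a cudaEvent_t is lazily created
  // and recorded on `stream` after run_impl returns, so is_finished() and
  // wait_for_completion() can later query or synchronize on it; without CUDA,
  // run_finished_ is a plain bool toggled around the call.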
  void run(
      AtenTensorHandle*
          input_handles, // array of input AtenTensorHandle; handles
                         // are stolen; the array itself is borrowed
      AtenTensorHandle*
          output_handles, // array for writing output AtenTensorHandle; handles
                          // will be stolen by the caller; the array itself is
                          // borrowed
      DeviceStreamType stream,
      AOTIProxyExecutorHandle proxy_executor) {
#ifdef USE_CUDA
    if (!run_finished_) {
      cudaEvent_t run_finished;
      AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished));
      run_finished_.emplace(run_finished);
    }

    auto* model = static_cast<Model*>(this);
    model->run_impl(input_handles, output_handles, stream, proxy_executor);
    AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream));
#else // !USE_CUDA
    run_finished_ = false;
    auto* model = static_cast<Model*>(this);
    model->run_impl(input_handles, output_handles, stream, proxy_executor);
    run_finished_ = true;
#endif // USE_CUDA
  }

  std::unordered_map<std::string, AtenTensorHandle> run_const_fold(
      DeviceStreamType stream,
      AOTIProxyExecutorHandle proxy_executor,
      bool initialization = false) {
#ifdef USE_CUDA
    if (!run_finished_) {
      cudaEvent_t run_finished;
      AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished));
      run_finished_.emplace(run_finished);
    }
#else // USE_CUDA
    run_finished_ = false;
#endif // USE_CUDA

    auto* model = static_cast<Model*>(this);
    auto folded_constants =
        model->const_run_impl(stream, proxy_executor, initialization);

#ifdef USE_CUDA
    AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream));
#else // USE_CUDA
    run_finished_ = true;
#endif // USE_CUDA

    return folded_constants;
  }

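  // load_constants walks the constants packed into the binary (or mmap'd
  // weights, see _get_constants_start), optionally copies them into a single
  // device blob on CUDA, wraps each one in an AtenTensorHandle via
  // aoti_torch_create_tensor_from_blob(_v2), and records it in constants_map_
  // and constants_.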
  void load_constants() {
    size_t num_constants = this->num_constants();
    constants_map_->reserve(num_constants);

    std::vector<size_t> constants_internal_offset(num_constants);
    if (device_type_ != aoti_torch_device_type_cpu()) {
      size_t blob_size = 0;
      compute_cuda_constant_blob(blob_size, constants_internal_offset);
#ifdef USE_CUDA
      constant_blob_ = RAII_cudaMalloc(blob_size);
#endif
    }

    size_t bytes_read = 0;
    for (size_t i = 0; i < num_constants; i++) {
      bool from_folded = this->constant_from_folded(i);
#ifndef USE_CUDA
      if (from_folded) {
        // We do not reallocate and copy for CPU.
        continue;
      }
#endif // USE_CUDA
      std::string name = this->constant_name(i);
      size_t data_size = this->constant_data_size(i);
      uint8_t* internal_ptr = (data_size != 0)
          ? constant_ptr(
                constants_internal_offset[i],
                bytes_read,
                data_size,
                from_folded)
          : nullptr;
      bytes_read += data_size;

      // Create at::Tensor from copied memory.
      auto dtype = this->constant_dtype(i);
      auto ndim = this->constant_ndim(i);
      auto size = this->constant_shape(i);
      auto stride = this->constant_stride(i);
      auto offset = this->constant_offset(i);
      auto layout = this->constant_layout(i);
      auto opaque_metadata_ptr = this->opaque_metadata(i);
      auto opaque_metadata_size = this->opaque_metadata_size(i);

      AtenTensorHandle tensor_handle = nullptr;
#ifdef AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1
      // When opaque_metadata_size is not 0,
      // aoti_torch_create_tensor_from_blob_v2 must be available.
      AOTI_RUNTIME_CHECK(
          opaque_metadata_size == 0,
          "Expect opaque_metadata_size to be 0 when AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1 is defined");
      AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob(
          internal_ptr,
          ndim,
          size,
          stride,
          offset,
          dtype,
          device_type_,
          device_idx_,
          &tensor_handle));
#else
      AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2(
          internal_ptr,
          ndim,
          size,
          stride,
          offset,
          dtype,
          device_type_,
          device_idx_,
          &tensor_handle,
          layout,
          opaque_metadata_ptr,
          opaque_metadata_size));
#endif // AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1
      constants_map_->emplace(std::move(name), tensor_handle);
    }
    if (constants_map_) {
      this->update_constants_array_from_map();
    }
  }

#ifdef USE_CUDA
  CUDAPtr&& release_constant_blob() {
    return std::move(constant_blob_);
  }
#endif

  std::shared_ptr<std::vector<ConstantHandle>> get_constants_array() {
    return constants_;
  }

  int32_t get_device_idx() const {
    return device_idx_;
  }

  uint8_t* constant_ptr(
      size_t constant_offset,
      size_t bytes_read,
      size_t data_size,
      bool skip_copy) {
#ifdef USE_CUDA
    auto* constants_ptr = static_cast<uint8_t*>(constant_blob_.get());
    uint8_t* internal_ptr = constants_ptr + constant_offset;
    // Copy data to GPU memory
    // TODO: Handle shared storage case.
    if (!skip_copy) {
      AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy(
          internal_ptr,
          _get_constants_start() + bytes_read,
          data_size,
          cudaMemcpyHostToDevice));
    }
    return internal_ptr;

#else
    // Get a pointer to the constant, which is packed into the model at
    // compile time.
    AOTI_RUNTIME_CHECK(!skip_copy, "pure cpu mode doesn't support skip copy");
    return _get_constants_start() + bytes_read;
#endif // USE_CUDA
  }

  void compute_cuda_constant_blob(
      size_t& blob_size,
      std::vector<size_t>& constants_internal_offset) {
#ifdef USE_CUDA
    size_t num_constants = this->num_constants();
    // Compute the required blob size, with each constant padded to 64-byte
    // alignment on GPU.
    blob_size = 0;
    for (size_t i = 0; i < num_constants; i++) {
      size_t data_size = this->constant_data_size(i);
      if (data_size % AOTI_CONST_GPU_ALIGNMENT) {
        data_size = AOTI_CONST_GPU_ALIGNMENT +
            (data_size / AOTI_CONST_GPU_ALIGNMENT) * AOTI_CONST_GPU_ALIGNMENT;
      }
      constants_internal_offset[i] = blob_size;
      blob_size += data_size;
    }
#endif // USE_CUDA
  }
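  // Example of the rounding above: with AOTI_CONST_GPU_ALIGNMENT == 64, a
  // 100-byte constant is padded to 128 bytes (64 + (100 / 64) * 64), so the
  // next constant starts on a 64-byte boundary; sizes that are already
  // multiples of 64 are left unchanged.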

  size_t num_inputs() const {
    return inputs_info_.size();
  }

  size_t num_outputs() const {
    return outputs_info_.size();
  }

  size_t num_constants() const {
    return constants_info_.size();
  }

  const char* input_name(int64_t idx) const {
    return inputs_info_.at(idx).name;
  }

  const char* output_name(int64_t idx) const {
    return outputs_info_.at(idx).name;
  }

  const char* constant_name(int64_t idx) const {
    return constants_info_.at(idx).name;
  }

  size_t constant_ndim(int64_t idx) {
    return constants_info_.at(idx).shape.size();
  }

  const int64_t* constant_shape(int64_t idx) const {
    return constants_info_.at(idx).shape.data();
  }

  const int64_t* constant_stride(int64_t idx) const {
    return constants_info_.at(idx).stride.data();
  }

  int32_t constant_dtype(int64_t idx) const {
    return constants_info_.at(idx).dtype;
  }

  int32_t constant_layout(int64_t idx) const {
    return constants_info_.at(idx).layout;
  }

  size_t constant_offset(int64_t idx) const {
    return constants_info_.at(idx).offset;
  }

  size_t constant_data_size(int64_t idx) const {
    return constants_info_.at(idx).data_size;
  }

  const char* constant_original_fqn(int64_t idx) const {
    return constants_info_.at(idx).original_fqn;
  }

  const uint8_t* opaque_metadata(int64_t idx) const {
    return constants_info_.at(idx).opaque_metadata.data();
  }

  size_t opaque_metadata_size(int64_t idx) {
    return constants_info_.at(idx).opaque_metadata.size();
  }

  bool constant_from_folded(int64_t idx) const {
    return constants_info_.at(idx).from_folded;
  }

  const char* get_in_spec() const {
    return in_spec_.c_str();
  }

  const char* get_out_spec() const {
    return out_spec_.c_str();
  }

  void update_constants_array_from_map() {
    if (!constants_map_) {
      throw std::runtime_error{
          "constants_map_ was not ready when constants_ is trying to be constructed from it!"};
    }
    if (!constants_) {
      constants_ =
          std::make_shared<std::vector<ConstantHandle>>(constants_info_.size());
    } else {
      constants_->resize(constants_info_.size());
    }
    int idx = 0;
    for (const auto& info : constants_info_) {
      const auto it = constants_map_->find(info.name);
      if (it != constants_map_->end()) {
        constants_->at(idx) = ConstantHandle(it->second);
      }
      idx++;
    }
  }

  void update_constants_map(
      std::shared_ptr<ConstantMap> constants_map,
      bool remap_constants_array = true) {
    constants_map_ = std::move(constants_map);
    if (remap_constants_array) {
      update_constants_array_from_map();
    }
  }

  // This function allows us to update the constants_ that is used to look up
  // the corresponding constant tensor during runtime.
  void update_constants_array(
      std::shared_ptr<std::vector<ConstantHandle>> constants_array) {
    constants_ = std::move(constants_array);
  }

  /// Returns true if the current run has finished.
  bool is_finished() {
#ifdef USE_CUDA
    if (!run_finished_) {
      throw std::runtime_error{"Model CUDA event was not initialized"};
    }

    auto event_status = cudaEventQuery(*run_finished_);
    if (event_status == cudaSuccess) {
      return true;
    } else if (event_status == cudaErrorNotReady) {
      return false;
    }

    throw std::runtime_error(
        std::string("The model did not finish successfully. Error: ") +
        cudaGetErrorString(cudaGetLastError()));
#else // !USE_CUDA
    return run_finished_;
#endif // USE_CUDA
  }

  /// Synchronizes on the completion event.
  void wait_for_completion() {
#ifdef USE_CUDA
    if (!run_finished_) {
      throw std::runtime_error{"Model event was not initialized"};
    }

    AOTI_RUNTIME_DEVICE_CHECK(cudaEventSynchronize(*run_finished_));
#endif // USE_CUDA
  }

 protected:
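  // _get_constants_start returns the base address of the serialized constant
  // bytes. Without USE_MMAP_SELF the bytes are the linked-in
  // _binary_constants_bin_start blob; with USE_MMAP_SELF the weights are
  // expected to be appended to this shared library at a 16K-aligned offset,
  // with _binary_constants_bin_start holding the weights size and a magic
  // number that is repeated in the last 8 bytes of the mapped region as a
  // sanity check.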
  uint8_t* _get_constants_start() {
#ifndef USE_MMAP_SELF
    return const_cast<uint8_t*>(_binary_constants_bin_start);
#else
    if (self_mmap) {
      return self_mmap;
    }
    Dl_info dl_info;
    // Get a pointer to the constants, which are appended to the binary.
    AOTI_RUNTIME_CHECK(
        dladdr(__func__, &dl_info), "Can't find shared library name");
    int fd = open(dl_info.dli_fname, O_RDONLY);
    AOTI_RUNTIME_CHECK(fd >= 0, "Shared library file cannot be opened");
    auto fsize = lseek(fd, 0, SEEK_END);
    auto weights_size =
        reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[0];
    auto magic_number =
        reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[1];
    auto weights_offset = fsize - weights_size;
    AOTI_RUNTIME_CHECK(
        (weights_offset & 0x3fff) == 0,
        "weights_offset must be aligned to 16K boundary");
    auto ptr = mmap(
        NULL,
        weights_size,
        PROT_READ | PROT_WRITE,
        MAP_PRIVATE,
        fd,
        weights_offset);
    close(fd);
    AOTI_RUNTIME_CHECK(ptr != MAP_FAILED, "mmap() failed");
    self_mmap = static_cast<uint8_t*>(ptr);
    AOTI_RUNTIME_CHECK(
        reinterpret_cast<uint64_t*>(
            self_mmap + weights_size - sizeof(uint64_t))[0] == magic_number,
        "Weights data seems corrupt");
    return self_mmap;
#endif
  }
  struct ParamInfo {
    const char* name = nullptr;
  };

  struct ConstInfo {
    const char* name = nullptr;
    std::vector<int64_t> shape;
    std::vector<int64_t> stride;
    int32_t dtype{};
    int64_t offset{};
    size_t data_size{};
    int32_t layout{};
    std::vector<uint8_t> opaque_metadata;
    int64_t opaque_metadata_size{};
    const char* original_fqn = nullptr;
    bool from_folded{};
  };

  std::vector<ParamInfo> inputs_info_;
  std::vector<ParamInfo> outputs_info_;
  std::vector<ConstInfo> constants_info_;
  std::string in_spec_;
  std::string out_spec_;

  std::shared_ptr<ConstantMap> constants_map_;
  std::shared_ptr<std::vector<ConstantHandle>> constants_;

#ifdef USE_CUDA
  // Holds the blob storage for constants' at::Tensor for CUDA.
  CUDAPtr constant_blob_;
#endif // USE_CUDA
#ifdef USE_MMAP_SELF
  uint8_t* self_mmap = NULL;
#endif

  // A directory with CUDA binary files, e.g. compiled kernels, etc.
  const std::optional<std::string> cubin_dir_;

  // Record if the model finishes an inference run so that its owning
  // AOTModelContainer can re-use this instance.
#ifdef USE_CUDA
  std::optional<cudaEvent_t> run_finished_;
#else // !USE_CUDA
  bool run_finished_{};
#endif

  // Generated model uses this device index to create CUDA guards.
  int32_t device_type_{};
  int32_t device_idx_{};
};

// Codegen-ed classes can derive from this to keep pointers to loaded kernels.
class AOTInductorModelKernelsBase {
 public:
  virtual ~AOTInductorModelKernelsBase() = default;
};

class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
 public:
  AOTInductorModel(
      std::shared_ptr<ConstantMap> constants_map,
      std::shared_ptr<std::vector<ConstantHandle>> constants_array,
      const std::string& device_str,
      std::optional<std::string> cubin_dir);

  std::unordered_map<std::string, AtenTensorHandle> const_run_impl(
      DeviceStreamType stream,
      AOTIProxyExecutorHandle proxy_executor,
      bool initialization = false);

  void _const_run_impl(
      std::vector<AtenTensorHandle>& output_handles,
      DeviceStreamType stream,
      AOTIProxyExecutorHandle proxy_executor);

  void run_impl(
      AtenTensorHandle*
          input_handles, // array of input AtenTensorHandle; handles
                         // are stolen; the array itself is borrowed
      AtenTensorHandle*
          output_handles, // array for writing output AtenTensorHandle; handles
                          // will be stolen by the caller; the array itself is
                          // borrowed
      DeviceStreamType stream,
      AOTIProxyExecutorHandle proxy_executor);

  template <typename Inputs, typename Outputs>
  Outputs run_impl_minimal_arrayref_interface(
      const Inputs& inputs,
      DeviceStreamType stream,
      AOTIProxyExecutorHandle proxy_executor);

  static std::unique_ptr<AOTInductorModel> Create(
      std::shared_ptr<ConstantMap> constants_map,
      std::shared_ptr<std::vector<ConstantHandle>> constants_array,
      const std::string& device_str,
      std::optional<std::string> cubin_dir) {
    return std::make_unique<AOTInductorModel>(
        std::move(constants_map),
        std::move(constants_array),
        device_str,
        std::move(cubin_dir));
  }

 private:
  std::unique_ptr<AOTInductorModelKernelsBase> kernels_;
};
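
// A rough usage sketch, assuming the caller supplies the stream, handle
// arrays, proxy executor, and cubin directory (placeholder names below); the
// real driver is the AOTInductor model container, which also manages
// constant updates and reuse of model instances:
//
//   auto constants_map = std::make_shared<ConstantMap>();
//   auto constants_array = std::make_shared<std::vector<ConstantHandle>>();
//   auto model = AOTInductorModel::Create(
//       constants_map, constants_array, "cuda:0", cubin_dir);
//   model->load_constants();
//   model->run(input_handles, output_handles, stream, proxy_executor);
//   model->wait_for_completion();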

} // namespace torch::aot_inductor