// xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_eager/kernel_holder.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #if !defined(C10_MOBILE) && !defined(ANDROID)
2 #pragma once
3 
#include <ATen/ATen.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/function_schema.h>

#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/inductor/aoti_eager/kernel_meta_info.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/utils/pybind.h>

#include <algorithm>
#include <memory>
#include <string>
#include <vector>
14 
15 namespace torch::inductor {
16 
17 // Represent AOTI kernel. It contains all the parameter metadata of the kernel
18 // and the AOTI model runner.
19 struct AOTIKernelMetadata {
20   // Represent all the parameters of AOTI kernel
21   std::vector<ParameterMetadata> parameter_metadata_list_;
22   // AOTI model runner to run the AOTI kernel
23   std::shared_ptr<AOTIModelContainerRunner> kernel_runner_;
AOTIKernelMetadataAOTIKernelMetadata24   AOTIKernelMetadata() : parameter_metadata_list_(), kernel_runner_(nullptr) {}
25 
26   // Check whether the given parameter metadata list is the same as the
27   // parameter metadata list of the AOTI kernel.
checkAOTIKernelMetadata28   bool check(
29       const std::vector<ParameterMetadata>& parameter_metadata_list) const {
30     if (parameter_metadata_list_.size() != parameter_metadata_list.size()) {
31       return false;
32     }
33 
34     for (size_t i = 0; i < parameter_metadata_list_.size(); ++i) {
35       if (parameter_metadata_list_[i] == parameter_metadata_list[i]) {
36         continue;
37       } else {
38         return false;
39       }
40     }
41 
42     return true;
43   }
44 };
45 
// The AOTIPythonKernelHolder class uses the AOT Inductor to generate a kernel
// for a specified operation. To speed up this process, the generated kernel
// library is cached on disk. Detailed information from the input tensors is
// used as the key for caching the kernel library. On subsequent runs, these
// input tensors are used to search the cache. If a cache hit occurs, the cached
// kernel library is loaded and executed. If a cache miss occurs, the AOT
// Inductor is called again to generate the kernel library.
class AOTIPythonKernelHolder : public c10::OperatorKernel {
  // A DispatchKey object that represents the dispatch key for the kernel.
  c10::DispatchKey dispatch_key_;
  // Namespace of the kernel.
  std::string ns_;
  // Name of the operation the kernel performs, including its overload name.
  std::string op_name_with_overload_;
  // The device on which the kernel is to be executed.
  c10::Device device_;
  // The Python interpreter to get OpOverload object with the given op_name and
  // op_overload_name.
  c10::impl::PyInterpreter* pyinterpreter_;
  // Cache the produced kernels by AOTI and its metadata. The in-memory,
  // per-holder mirror of the on-disk kernel cache.
  std::vector<AOTIKernelMetadata> aoti_kernel_cache_;

 public:
  // Construct a holder for the operator `ns::op_name_with_overload`
  // registered under `dispatch_key`.
  AOTIPythonKernelHolder(
      c10::DispatchKey dispatch_key,
      c10::string_view ns,
      c10::string_view op_name_with_overload);

  // Boxed-kernel call operator invoked by the dispatcher; the operation's
  // arguments arrive on `stack` (results are returned via the stack per the
  // boxed-kernel convention — see the .cpp for details).
  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack);

 private:
  // Search aoti_kernel_cache_ for a kernel matching the arguments on `stack`;
  // on a hit, `aoti_kernel_metadata` receives the matching entry and true is
  // returned (behavior inferred from the signature — confirm against the .cpp).
  bool cache_lookup(
      const c10::OperatorHandle& op,
      const c10::DispatchKeySet& keyset,
      const torch::jit::Stack* stack,
      AOTIKernelMetadata& aoti_kernel_metadata);
  // Called when no cached kernel matches; presumably compiles a new kernel
  // (via produce_aoti_kernel_lib) and executes it — see the .cpp.
  void cache_miss(
      const c10::OperatorHandle& op,
      const c10::DispatchKeySet& keyset,
      torch::jit::Stack* stack);
  // Called when a cached kernel matches; runs the kernel held by
  // `aoti_kernel_metadata` on the stack arguments — see the .cpp.
  void cache_hit(
      const AOTIKernelMetadata& aoti_kernel_metadata,
      const c10::OperatorHandle& op,
      const c10::DispatchKeySet& keyset,
      torch::jit::Stack* stack);
  // Invoke python utility function on the Inductor side to produce AOTI kernel
  // for the given operation.
  //   Inductor utility function -
  //   torch._inductor.utils.aoti_compile_with_persistent_cache
  // Returns the path of the produced kernel library.
  std::string produce_aoti_kernel_lib(
      const c10::OperatorHandle& op,
      const c10::DispatchKeySet& keyset,
      const torch::jit::Stack* stack);
  // Invoke python utility function on the Inductor side to load AOTI kernel for
  // the given operation.
  //   Inductor utility function - torch._inductor.utils.load_aoti_eager_cache
  void init_aoti_kernel_cache();
  // Load the AOTIModelContainerRunner object from the given file path.
  std::shared_ptr<AOTIModelContainerRunner> load_aoti_model_runner(
      const std::string&);
};
110 
111 } // namespace torch::inductor
112 #endif
113