1 #if !defined(C10_MOBILE) && !defined(ANDROID) 2 #pragma once 3 4 #include <ATen/ATen.h> 5 #include <ATen/core/boxing/KernelFunction.h> 6 #include <ATen/core/function_schema.h> 7 8 #include <torch/csrc/dynamo/guards.h> 9 #include <torch/csrc/inductor/aoti_eager/kernel_meta_info.h> 10 #include <torch/csrc/inductor/aoti_runner/model_container_runner.h> 11 #include <torch/csrc/utils/pybind.h> 12 13 #include <string> 14 15 namespace torch::inductor { 16 17 // Represent AOTI kernel. It contains all the parameter metadata of the kernel 18 // and the AOTI model runner. 19 struct AOTIKernelMetadata { 20 // Represent all the parameters of AOTI kernel 21 std::vector<ParameterMetadata> parameter_metadata_list_; 22 // AOTI model runner to run the AOTI kernel 23 std::shared_ptr<AOTIModelContainerRunner> kernel_runner_; AOTIKernelMetadataAOTIKernelMetadata24 AOTIKernelMetadata() : parameter_metadata_list_(), kernel_runner_(nullptr) {} 25 26 // Check whether the given parameter metadata list is the same as the 27 // parameter metadata list of the AOTI kernel. checkAOTIKernelMetadata28 bool check( 29 const std::vector<ParameterMetadata>& parameter_metadata_list) const { 30 if (parameter_metadata_list_.size() != parameter_metadata_list.size()) { 31 return false; 32 } 33 34 for (size_t i = 0; i < parameter_metadata_list_.size(); ++i) { 35 if (parameter_metadata_list_[i] == parameter_metadata_list[i]) { 36 continue; 37 } else { 38 return false; 39 } 40 } 41 42 return true; 43 } 44 }; 45 46 // The AOTIPythonKernelHolder class uses the AOT Inductor to generate a kernel 47 // for a specified operation. To speed up this process, the generated kernel 48 // library is cached on disk. Detailed information from the input tensors is 49 // used as the key for caching the kernel library. On subsequent runs, these 50 // input tensors are used to search the cache. If a cache hit occurs, the cached 51 // kernel library is loaded and executed. If a cache miss occurs, the AOT 52 // Inductor is called again to generate the kernel library. 53 class AOTIPythonKernelHolder : public c10::OperatorKernel { 54 // A DispatchKey object that represents the dispatch key for the kernel. 55 c10::DispatchKey dispatch_key_; 56 // Namespace of the kernel. 57 std::string ns_; 58 // Name of the operation the kernel performs. 59 std::string op_name_with_overload_; 60 // The device on which the kernel is to be executed. 61 c10::Device device_; 62 // The Python interpreter to get OpOverload object with the given op_name and 63 // op_overload_name. 64 c10::impl::PyInterpreter* pyinterpreter_; 65 // Cache the produced kernels by AOTI and its metadata 66 std::vector<AOTIKernelMetadata> aoti_kernel_cache_; 67 68 public: 69 AOTIPythonKernelHolder( 70 c10::DispatchKey dispatch_key, 71 c10::string_view ns, 72 c10::string_view op_name_with_overload); 73 74 void operator()( 75 const c10::OperatorHandle& op, 76 c10::DispatchKeySet keyset, 77 torch::jit::Stack* stack); 78 79 private: 80 bool cache_lookup( 81 const c10::OperatorHandle& op, 82 const c10::DispatchKeySet& keyset, 83 const torch::jit::Stack* stack, 84 AOTIKernelMetadata& aoti_kernel_metadata); 85 void cache_miss( 86 const c10::OperatorHandle& op, 87 const c10::DispatchKeySet& keyset, 88 torch::jit::Stack* stack); 89 void cache_hit( 90 const AOTIKernelMetadata& aoti_kernel_metadata, 91 const c10::OperatorHandle& op, 92 const c10::DispatchKeySet& keyset, 93 torch::jit::Stack* stack); 94 // Invoke python utility function on the Inductor side to produce AOTI kernel 95 // for the given operation. 96 // Inductor utility function - 97 // torch._inductor.utils.aoti_compile_with_persistent_cache 98 std::string produce_aoti_kernel_lib( 99 const c10::OperatorHandle& op, 100 const c10::DispatchKeySet& keyset, 101 const torch::jit::Stack* stack); 102 // Invoke python utility function on the Inductor side to load AOTI kernel for 103 // the given operation. 104 // Inductor utility function - torch._inductor.utils.load_aoti_eager_cache 105 void init_aoti_kernel_cache(); 106 // Load the AOTIModelContainerRunner object from the given file path. 107 std::shared_ptr<AOTIModelContainerRunner> load_aoti_model_runner( 108 const std::string&); 109 }; 110 111 } // namespace torch::inductor 112 #endif 113