xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <nlohmann/json.hpp>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
11 
12 namespace torch::aot_inductor {
13 
// Kind tag for an argument/output whose value is supplied at call time.
// The integer values are pinned explicitly because operator<< prints the
// underlying int; do not renumber without checking every consumer.
enum class DynamicArgType : int {
  // Counted as "tensor" kinds by isTensorType():
  TensorType = 0,
  ListTensorType = 1,
  ListOptionalTensorType = 2,
  // Not counted as "tensor" kinds by isTensorType():
  IntType = 3,
  ListIntType = 4,
};
21 
22 inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) {
23   os << static_cast<int>(arg_type);
24   return os;
25 }
26 
isTensorType(DynamicArgType arg_type)27 inline bool isTensorType(DynamicArgType arg_type) {
28   return arg_type == DynamicArgType::TensorType ||
29       arg_type == DynamicArgType::ListTensorType ||
30       arg_type == DynamicArgType::ListOptionalTensorType;
31 }
32 
33 struct OSSDynamicArg {
34   OSSDynamicArg(
35       int arg_index,
36       DynamicArgType arg_type,
37       int length,
38       std::optional<std::vector<std::string>> list_item_types = std::nullopt)
arg_indexOSSDynamicArg39       : arg_index(arg_index),
40         arg_type(arg_type),
41         length(length),
42         list_item_types(std::move(list_item_types)) {}
43   int arg_index;
44   DynamicArgType arg_type;
45   int length;
46   std::optional<std::vector<std::string>>
47       list_item_types; // only used for parsing list of optional tensors
48 };
49 
50 struct OSSOpKernel {
OSSOpKernelOSSOpKernel51   OSSOpKernel(std::string target, c10::OperatorHandle op_handle)
52       : target_(std::move(target)), op_handle_(std::move(op_handle)) {}
53 
54   std::string target_;
55   c10::OperatorHandle op_handle_;
56   std::vector<OSSDynamicArg> dynamic_args_;
57   std::vector<OSSDynamicArg> outputs_;
58   std::vector<c10::IValue> stack_;
59 
num_output_tensorsOSSOpKernel60   int num_output_tensors() const {
61     int num_output_tensors = 0;
62     for (const auto& output : outputs_) {
63       if (isTensorType(output.arg_type)) {
64         num_output_tensors += output.length;
65       }
66     }
67     return num_output_tensors;
68   }
69 };
70 
71 class OSSProxyExecutor : public ProxyExecutor {
72  public:
73   explicit OSSProxyExecutor(const std::string& json_path, bool is_cpu);
74 
75   void call_function(
76       int extern_node_index,
77       int num_ints,
78       int64_t* flatten_int_args,
79       int num_tensors,
80       AtenTensorHandle* flatten_tensor_args) override;
81 
82  private:
83   void prefill_stack_with_static_arguments(
84       int index,
85       at::TypePtr schema_arg_type,
86       const nlohmann::json& serialized_arg,
87       OSSOpKernel& op_kernel);
88 
89   void get_input_info_from_serialized(
90       const std::vector<c10::Argument>& schema_args,
91       const nlohmann::json& serialized_node,
92       OSSOpKernel& op_kernel);
93 
94   void get_output_info_from_serialized(
95       const std::vector<c10::Argument>& schema_returns,
96       const nlohmann::json& serialized_node,
97       OSSOpKernel& op_kernel);
98 
99   std::vector<OSSOpKernel> op_kernels_;
100   std::unique_ptr<c10::Device> device_;
101 };
102 
103 } // namespace torch::aot_inductor
104