1 #pragma once
2
3 #include <ATen/core/dispatch/Dispatcher.h>
4 #include <ATen/core/ivalue.h>
5 #include <c10/macros/Export.h>
6 #include <nlohmann/json.hpp>
7 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
8 #include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
9 #include <iostream>
10 #include <utility>
11
12 namespace torch::aot_inductor {
13
// Kind tag for a dynamic argument or output slot of an op kernel.
//
// Every enumerator has an explicitly pinned integer value; operator<< below
// prints exactly that integer.
// NOTE(review): the values presumably have to stay in sync with whatever
// producer writes them (serialized metadata / generated code) -- confirm
// before renumbering or inserting entries in the middle.
enum class DynamicArgType : int {
  // The first three kinds are the ones isTensorType() reports as tensor kinds.
  TensorType = 0,
  ListTensorType = 1,
  ListOptionalTensorType = 2,
  // Integer kinds; isTensorType() returns false for these.
  IntType = 3,
  ListIntType = 4,
};
21
22 inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) {
23 os << static_cast<int>(arg_type);
24 return os;
25 }
26
isTensorType(DynamicArgType arg_type)27 inline bool isTensorType(DynamicArgType arg_type) {
28 return arg_type == DynamicArgType::TensorType ||
29 arg_type == DynamicArgType::ListTensorType ||
30 arg_type == DynamicArgType::ListOptionalTensorType;
31 }
32
33 struct OSSDynamicArg {
34 OSSDynamicArg(
35 int arg_index,
36 DynamicArgType arg_type,
37 int length,
38 std::optional<std::vector<std::string>> list_item_types = std::nullopt)
arg_indexOSSDynamicArg39 : arg_index(arg_index),
40 arg_type(arg_type),
41 length(length),
42 list_item_types(std::move(list_item_types)) {}
43 int arg_index;
44 DynamicArgType arg_type;
45 int length;
46 std::optional<std::vector<std::string>>
47 list_item_types; // only used for parsing list of optional tensors
48 };
49
50 struct OSSOpKernel {
OSSOpKernelOSSOpKernel51 OSSOpKernel(std::string target, c10::OperatorHandle op_handle)
52 : target_(std::move(target)), op_handle_(std::move(op_handle)) {}
53
54 std::string target_;
55 c10::OperatorHandle op_handle_;
56 std::vector<OSSDynamicArg> dynamic_args_;
57 std::vector<OSSDynamicArg> outputs_;
58 std::vector<c10::IValue> stack_;
59
num_output_tensorsOSSOpKernel60 int num_output_tensors() const {
61 int num_output_tensors = 0;
62 for (const auto& output : outputs_) {
63 if (isTensorType(output.arg_type)) {
64 num_output_tensors += output.length;
65 }
66 }
67 return num_output_tensors;
68 }
69 };
70
71 class OSSProxyExecutor : public ProxyExecutor {
72 public:
73 explicit OSSProxyExecutor(const std::string& json_path, bool is_cpu);
74
75 void call_function(
76 int extern_node_index,
77 int num_ints,
78 int64_t* flatten_int_args,
79 int num_tensors,
80 AtenTensorHandle* flatten_tensor_args) override;
81
82 private:
83 void prefill_stack_with_static_arguments(
84 int index,
85 at::TypePtr schema_arg_type,
86 const nlohmann::json& serialized_arg,
87 OSSOpKernel& op_kernel);
88
89 void get_input_info_from_serialized(
90 const std::vector<c10::Argument>& schema_args,
91 const nlohmann::json& serialized_node,
92 OSSOpKernel& op_kernel);
93
94 void get_output_info_from_serialized(
95 const std::vector<c10::Argument>& schema_returns,
96 const nlohmann::json& serialized_node,
97 OSSOpKernel& op_kernel);
98
99 std::vector<OSSOpKernel> op_kernels_;
100 std::unique_ptr<c10::Device> device_;
101 };
102
103 } // namespace torch::aot_inductor
104