// torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp
#include <nlohmann/json.hpp>
#include <fstream>
#include <iostream>

#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>

namespace {
at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) {
  return reinterpret_cast<at::Tensor*>(handle);
}
} // namespace
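
// Note: AtenTensorHandle is the opaque handle type used across the AOTI C
// ABI; the helper above assumes each handle points at a live at::Tensor
// owned by the caller, so the conversion is a plain reinterpret_cast.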

namespace torch::aot_inductor {

void OSSProxyExecutor::prefill_stack_with_static_arguments(
    int index,
    at::TypePtr schema_arg_type,
    const nlohmann::json& serialized_arg,
    OSSOpKernel& op_kernel) {
  auto& stack = op_kernel.stack_;
  auto& dynamic_args = op_kernel.dynamic_args_;

  TORCH_CHECK(serialized_arg.size() == 1);
  std::string serialized_arg_type = serialized_arg.begin().key();
  auto& serialized_arg_val = serialized_arg.begin().value();

  switch (schema_arg_type->kind()) {
    case c10::TypeKind::TensorType: {
      TORCH_CHECK(serialized_arg_type == "as_tensor");
      stack.emplace_back();
      dynamic_args.emplace_back(index, DynamicArgType::TensorType, 1);
      break;
    }
    case c10::TypeKind::IntType: {
      TORCH_CHECK(serialized_arg_type == "as_int");
      stack.emplace_back(c10::IValue());
      dynamic_args.emplace_back(index, DynamicArgType::IntType, 1);
      break;
    }
    case c10::TypeKind::SymIntType: {
      TORCH_CHECK(
          serialized_arg_type == "as_int" ||
          serialized_arg_type == "as_sym_int");
      stack.emplace_back(c10::IValue());
      dynamic_args.emplace_back(index, DynamicArgType::IntType, 1);
      break;
    }
    case c10::TypeKind::FloatType: {
      TORCH_CHECK(serialized_arg_type == "as_float");
      stack.emplace_back(serialized_arg_val.get<double>());
      break;
    }
    case c10::TypeKind::BoolType: {
      TORCH_CHECK(serialized_arg_type == "as_bool");
      stack.emplace_back(serialized_arg_val.get<bool>());
      break;
    }
    case c10::TypeKind::NumberType: {
      if (serialized_arg_type == "as_int") {
        // Only int Scalars are treated as dynamic args for now
        stack.emplace_back();
        dynamic_args.emplace_back(index, DynamicArgType::IntType, 1);
      } else if (serialized_arg_type == "as_float") {
        stack.emplace_back(serialized_arg_val.get<double>());
      } else if (serialized_arg_type == "as_bool") {
        stack.emplace_back(serialized_arg_val.get<bool>());
      } else {
        TORCH_CHECK(
            false,
            "Invalid serialized argument type found for Scalar input: ",
            serialized_arg_type);
      }
      break;
    }
    case c10::TypeKind::StringType: {
      TORCH_CHECK(serialized_arg_type == "as_string");
      stack.emplace_back(serialized_arg_val.get<std::string>());
      break;
    }
    case c10::TypeKind::DeviceObjType: {
      TORCH_CHECK(serialized_arg_type == "as_device");

      std::string device_string = serialized_arg_val["type"].get<std::string>();
      if (serialized_arg_val["index"].is_number()) {
        // The index is serialized as a JSON number; convert it to a string
        // before appending (calling get<std::string>() on a number would
        // throw a type error).
        device_string +=
            ":" + std::to_string(serialized_arg_val["index"].get<int>());
      }

      c10::Device device(device_string);

      if (device != *device_) {
        VLOG(1) << "ProxyExecutor is using " << *device_ << " for "
                << op_kernel.target_ << " argument #" << index
                << ", which is different from the device in the serialized "
                << "model: " << device << ". Please ensure this is intentional.";
      }

      stack.emplace_back(*device_);
      break;
    }
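
    // Illustrative (hypothetical) serialized device argument:
    //   {"as_device": {"type": "cuda", "index": 0}}
    // which parses to the device string "cuda:0". Note that the executor
    // still pushes its own device_ onto the stack, not the serialized one.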
    case c10::TypeKind::ListType: {
      if (schema_arg_type->isSubtypeOf(at::ListType::ofTensors())) {
        TORCH_CHECK(serialized_arg_type == "as_tensors");
        stack.emplace_back();
        dynamic_args.emplace_back(
            index, DynamicArgType::ListTensorType, serialized_arg_val.size());
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofInts())) {
        TORCH_CHECK(serialized_arg_type == "as_ints");
        dynamic_args.emplace_back(
            index, DynamicArgType::ListIntType, serialized_arg_val.size());
        stack.emplace_back(c10::IValue());
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) {
        TORCH_CHECK(
            serialized_arg_type == "as_ints" ||
            serialized_arg_type == "as_sym_ints");
        dynamic_args.emplace_back(
            index, DynamicArgType::ListIntType, serialized_arg_val.size());
        stack.emplace_back(c10::IValue());
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofFloats())) {
        TORCH_CHECK(serialized_arg_type == "as_floats");
        std::vector<double> ret;
        for (const auto& arg : serialized_arg_val) {
          ret.push_back(arg.get<double>());
        }
        stack.emplace_back(ret);
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofBools())) {
        TORCH_CHECK(serialized_arg_type == "as_bools");
        std::vector<bool> ret;
        for (const auto& arg : serialized_arg_val) {
          ret.push_back(arg.get<bool>());
        }
        stack.emplace_back(ret);
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofNumbers())) {
        if (serialized_arg_type == "as_ints") {
          dynamic_args.emplace_back(
              index, DynamicArgType::ListIntType, serialized_arg_val.size());
          stack.emplace_back(c10::IValue());
        } else if (serialized_arg_type == "as_floats") {
          std::vector<double> ret;
          for (const auto& arg : serialized_arg_val) {
            ret.push_back(arg.get<double>());
          }
          stack.emplace_back(ret);
        } else if (serialized_arg_type == "as_bools") {
          std::vector<bool> ret;
          for (const auto& arg : serialized_arg_val) {
            ret.push_back(arg.get<bool>());
          }
          stack.emplace_back(ret);
        } else {
          TORCH_CHECK(
              false,
              "Invalid serialized argument type found for List[Scalar]: ",
              serialized_arg_type);
        }
      } else if (schema_arg_type->isSubtypeOf(
                     at::ListType::ofOptionalTensors())) {
        if (serialized_arg_type == "as_optional_tensors") {
          std::vector<std::string> list_item_types;
          for (const auto& arg : serialized_arg_val) {
            list_item_types.push_back(arg.begin().key());
          }
          stack.emplace_back();
          dynamic_args.emplace_back(
              index,
              DynamicArgType::ListOptionalTensorType,
              serialized_arg_val.size(),
              list_item_types);
        } else if (serialized_arg_type == "as_tensors") {
          stack.emplace_back();
          dynamic_args.emplace_back(
              index, DynamicArgType::ListTensorType, serialized_arg_val.size());
        } else {
          TORCH_CHECK(
              false,
              "Invalid serialized type found for argument of type `Tensor?[]`: ",
              serialized_arg_type);
        }
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofStrings())) {
        TORCH_CHECK(serialized_arg_type == "as_strings");
        std::vector<std::string> ret;
        for (const auto& arg : serialized_arg_val) {
          ret.push_back(arg.get<std::string>());
        }
        stack.emplace_back(ret);
      } else {
        TORCH_CHECK(false, "NYI: Unsupported list type ", serialized_arg_type);
      }
      break;
    }
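
    // Illustrative (hypothetical) `Tensor?[]` argument: a serialized value
    //   {"as_optional_tensors": [{"as_tensor": ...}, {"as_none": {}}]}
    // records the per-item keys ("as_tensor" / "as_none") so that
    // call_function later knows which list slots to fill with real tensors.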
    case c10::TypeKind::OptionalType: {
      auto inner_type =
          schema_arg_type->castRaw<at::OptionalType>()->getElementType();

      if (serialized_arg_type == "as_none") {
        stack.emplace_back(c10::nullopt);
        if (inner_type->kind() == c10::TypeKind::TensorType) {
          // Tensor is None
          dynamic_args.emplace_back(index, DynamicArgType::TensorType, 0);
        } else if (
            inner_type->kind() == c10::TypeKind::IntType ||
            inner_type->kind() == c10::TypeKind::SymIntType) {
          // Int or SymInt is None
          dynamic_args.emplace_back(index, DynamicArgType::IntType, 0);
        } else if (
            inner_type->kind() == c10::TypeKind::ListType &&
            inner_type->isSubtypeOf(at::ListType::ofTensors())) {
          // List[Tensor] is None
          dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, 0);
        } else if (
            inner_type->kind() == c10::TypeKind::ListType &&
            inner_type->isSubtypeOf(at::ListType::ofSymInts())) {
          // List[SymInt] is None
          dynamic_args.emplace_back(index, DynamicArgType::ListIntType, 0);
        }
      } else {
        prefill_stack_with_static_arguments(
            index, inner_type, serialized_arg, op_kernel);
      }
      break;
    }
    // TODO: handle the other input types
    default:
      TORCH_CHECK(false, "Unsupported input type ", serialized_arg_type);
  }
}
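
// Illustrative example (argument values hypothetical): for a schema argument
// `Tensor self` serialized as {"as_tensor": {...}}, the function above pushes
// an empty placeholder IValue onto stack_ and records
// (index, DynamicArgType::TensorType, length=1) in dynamic_args_; the real
// at::Tensor is patched into that stack slot later by call_function. Static
// arguments such as {"as_float": 0.5} are materialized onto the stack
// immediately instead.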

// Populates op_kernel.stack_, op_kernel.dynamic_args_
void OSSProxyExecutor::get_input_info_from_serialized(
    const std::vector<c10::Argument>& schema_args,
    const nlohmann::json& serialized_node,
    OSSOpKernel& op_kernel) {
  int index = 0;
  for (const auto& named_argument : serialized_node["inputs"]) {
    const auto& arg = named_argument["arg"];
    auto& schema_arg = schema_args[index];

    prefill_stack_with_static_arguments(
        index++, schema_arg.real_type(), arg, op_kernel);
  }

  // TODO: prefill default values
}
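
// Illustrative (hypothetical) shape of serialized_node["inputs"]; the "name"
// field is assumed here, only "arg" is read by the code above:
//   [{"name": "self", "arg": {"as_tensor": {...}}},
//    {"name": "alpha", "arg": {"as_float": 1.0}}]
// Each entry's "arg" payload is matched positionally against the op schema.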

// Populates op_kernel.outputs_
void OSSProxyExecutor::get_output_info_from_serialized(
    const std::vector<c10::Argument>& schema_returns,
    const nlohmann::json& serialized_node,
    OSSOpKernel& op_kernel) {
  std::vector<OSSDynamicArg>& outputs = op_kernel.outputs_;

  TORCH_CHECK(
      schema_returns.size() == serialized_node["outputs"].size(),
      "Serialized node doesn't match op's schema outputs.");

  size_t output_index = 0;
  for (const auto& serialized_output : serialized_node["outputs"]) {
    TORCH_CHECK(serialized_output.size() == 1);
    std::string serialized_output_type = serialized_output.begin().key();
    auto& serialized_output_val = serialized_output.begin().value();

    auto& schema_return = schema_returns[output_index];
    at::TypePtr schema_return_type = schema_return.real_type();

    switch (schema_return_type->kind()) {
      case c10::TypeKind::TensorType: {
        TORCH_CHECK(
            serialized_output_type == "as_tensor",
            serialized_node["target"],
            " got serialized_output_type of ",
            serialized_output_type);
        outputs.emplace_back(output_index, DynamicArgType::TensorType, 1);
        break;
      }
      case c10::TypeKind::ListType: {
        if (schema_return_type->isSubtypeOf(at::ListType::ofTensors())) {
          TORCH_CHECK(
              serialized_output_type == "as_tensors",
              serialized_node["target"],
              " got serialized_output_type of ",
              serialized_output_type);
          outputs.emplace_back(
              output_index,
              DynamicArgType::ListTensorType,
              serialized_output_val.size());
        } else {
          TORCH_CHECK(
              false,
              "Unsupported return list type ",
              schema_return_type->repr_str());
        }
        break;
      }
      default: {
        TORCH_CHECK(
            false, "Unsupported return type ", schema_return_type->repr_str());
      }
    }

    output_index++;
  }
}
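
// Illustrative (hypothetical) shape of serialized_node["outputs"]:
//   [{"as_tensor": {...}}] for a single-Tensor return, or
//   [{"as_tensors": [{...}, {...}]}] for a Tensor[] return.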

OSSProxyExecutor::OSSProxyExecutor(const std::string& json_path, bool is_cpu) {
  if (is_cpu) {
    device_ = std::make_unique<c10::Device>(c10::DeviceType::CPU);
  } else {
    int device_idx = -1;
    device_ = std::make_unique<c10::Device>(c10::DeviceType::CUDA, device_idx);
  }

  std::ifstream json_file(json_path);
  TORCH_CHECK(json_file.is_open());

  // Parse file into a json object
  nlohmann::json json_obj;
  json_file >> json_obj;

  // Access data
  for (auto const& serialized_extern_node : json_obj["nodes"]) {
    auto const& serialized_node = serialized_extern_node["node"];

    const std::string& target = serialized_node["target"];

    std::string opName;
    std::string overloadName;
    size_t pos = target.find('.');
    if (pos == std::string::npos) {
      opName = target;
      overloadName = "";
    } else {
      // There should be no more periods, so search after the first one
      size_t pos2 = target.find('.', pos + 1);
      TORCH_CHECK(pos2 == std::string::npos);

      opName = target.substr(0, pos);
      overloadName = target.substr(pos + 1);
    }

    c10::OperatorHandle op_handle =
        c10::Dispatcher::singleton().findSchemaOrThrow(
            opName.c_str(), overloadName.c_str());
    const c10::FunctionSchema& schema = op_handle.schema();

    const auto& schema_args = schema.arguments();
    const auto& schema_returns = schema.returns();

    OSSOpKernel op_kernel(target, op_handle);
    get_input_info_from_serialized(schema_args, serialized_node, op_kernel);
    get_output_info_from_serialized(schema_returns, serialized_node, op_kernel);

    op_kernels_.emplace_back(std::move(op_kernel));
  }
}
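
// Example of the expected JSON layout (field values are hypothetical):
//   {"nodes": [{"node": {
//       "target": "aten::fill_.Scalar",
//       "inputs": [...],
//       "outputs": [...]}}]}
// A target such as "aten::fill_.Scalar" splits at the first '.' into the
// qualified op name "aten::fill_" and the overload name "Scalar", which are
// then looked up in the dispatcher via findSchemaOrThrow.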

void OSSProxyExecutor::call_function(
    int extern_node_index,
    int num_ints,
    int64_t* flatten_int_args,
    int num_tensors,
    AtenTensorHandle* flatten_tensor_args) {
  TORCH_CHECK(
      extern_node_index < static_cast<int>(op_kernels_.size()),
      "Invalid extern node index");
  OSSOpKernel& op_kernel = op_kernels_[extern_node_index];

  std::vector<c10::IValue> stack = op_kernel.stack_;
  auto& dynamic_args = op_kernel.dynamic_args_;

  int tensor_id = 0;
  int int_id = 0;
  for (auto& dynamic_arg : dynamic_args) {
    int arg_index = dynamic_arg.arg_index;
    DynamicArgType dynamic_arg_type = dynamic_arg.arg_type;
    int length = dynamic_arg.length;

    if (length == 0) {
      continue;
    }

    switch (dynamic_arg_type) {
      case DynamicArgType::TensorType: {
        at::Tensor* tensor =
            tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
        stack[arg_index] = *tensor;
        break;
      }
      case DynamicArgType::IntType: {
        int64_t val = flatten_int_args[int_id++];
        stack[arg_index] = val;
        break;
      }
      case DynamicArgType::ListTensorType: {
        std::vector<at::Tensor> tensor_list;
        for (int j = 0; j < length; j++) {
          at::Tensor* tensor =
              tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
          tensor_list.push_back(*tensor);
        }
        stack[arg_index] = tensor_list;
        break;
      }
      case DynamicArgType::ListOptionalTensorType: {
        std::vector<std::optional<at::Tensor>> optional_tensor_list;
        auto& list_item_types = dynamic_arg.list_item_types;
        TORCH_CHECK(
            list_item_types.has_value(),
            "Could not find list of item types for optional tensor list input");

        for (const std::string& item_type : list_item_types.value()) {
          if (item_type == "as_tensor") {
            at::Tensor* tensor = tensor_handle_to_tensor_pointer(
                flatten_tensor_args[tensor_id++]);
            optional_tensor_list.emplace_back(*tensor);
          } else if (item_type == "as_none") {
            optional_tensor_list.emplace_back(c10::nullopt);
          }
        }
        stack[arg_index] = optional_tensor_list;
        break;
      }
      case DynamicArgType::ListIntType: {
        std::vector<int64_t> vals;
        for (int j = 0; j < length; j++) {
          vals.push_back(flatten_int_args[int_id++]);
        }
        stack[arg_index] = vals;
        break;
      }
      default:
        TORCH_CHECK(false, "Unsupported dynamic arg type: ", dynamic_arg_type);
    }
  }

  int num_output_tensors = op_kernel.num_output_tensors();
  TORCH_CHECK(
      tensor_id == num_tensors - num_output_tensors,
      "Mismatch between tensors consumed and number of input tensors, got tensor_id = ",
      tensor_id,
      ", expected num = ",
      num_tensors - num_output_tensors);
  TORCH_CHECK(
      int_id == num_ints,
      "Mismatch between ints consumed and num_ints, got int_id = ",
      int_id,
      ", num_ints = ",
      num_ints);

  // Call the op with the prepared stack.
  const c10::OperatorHandle& op = op_kernel.op_handle_;
  op.callBoxed(stack);

  const c10::FunctionSchema& schema = op.schema();
  const auto& schema_returns = schema.returns();

  TORCH_CHECK(op_kernel.outputs_.size() == stack.size());
  // TODO: what about optional outputs? This assert may not hold
  TORCH_CHECK(stack.size() == schema_returns.size());

  int index = 0;
  for (const auto& schema_return : schema_returns) {
    if (schema_return.type()->kind() == c10::TypeKind::TensorType) {
      at::Tensor* tensor =
          tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
      *tensor = stack[index++].toTensor();
    } else if (
        schema_return.type()->kind() == c10::TypeKind::ListType &&
        schema_return.type()->isSubtypeOf(at::ListType::ofTensors())) {
      auto tensors = stack[index++].toTensorList();
      for (size_t i = 0; i < tensors.size(); ++i) {
        at::Tensor* tensor =
            tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
        *tensor = tensors[i];
      }
    } else {
      TORCH_CHECK(
          false,
          "NYI: Unsupported return type for schema: ",
          schema_return.type()->repr_str());
    }
  }

  TORCH_CHECK(
      tensor_id == num_tensors,
      "Mismatch between tensors consumed and num_tensors, got tensor_id = ",
      tensor_id,
      ", expected num = ",
      num_tensors);
}
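
// Illustrative calling convention (values hypothetical): the generated code
// flattens dynamic ints and tensor handles into separate arrays, with input
// tensors followed by output tensors. For an extern node with two input
// tensors, one dynamic int, and one tensor output, a call would look
// roughly like:
//   int64_t ints[] = {42};
//   AtenTensorHandle tensors[] = {in0, in1, out0};
//   proxy_executor->call_function(
//       /*extern_node_index=*/0, /*num_ints=*/1, ints,
//       /*num_tensors=*/3, tensors);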

} // namespace torch::aot_inductor