#include <nlohmann/json.hpp>
#include <fstream>
#include <iostream>

#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>

namespace {
at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) {
  return reinterpret_cast<at::Tensor*>(handle);
}
} // namespace

namespace torch::aot_inductor {

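// Prefills op_kernel.stack_ with every argument whose value is statically
// known from the serialized node, and records a placeholder in
// op_kernel.dynamic_args_ for every argument that is only known at call time
// (tensors, ints, etc.). Serialized arguments are single-key JSON objects;
// illustrative sketches (not taken from a real graph) look like
// {"as_int": 42} or {"as_tensor": {...}}.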
void OSSProxyExecutor::prefill_stack_with_static_arguments(
    int index,
    at::TypePtr schema_arg_type,
    const nlohmann::json& serialized_arg,
    OSSOpKernel& op_kernel) {
  auto& stack = op_kernel.stack_;
  auto& dynamic_args = op_kernel.dynamic_args_;

  TORCH_CHECK(serialized_arg.size() == 1);
  std::string serialized_arg_type = serialized_arg.begin().key();
  auto& serialized_arg_val = serialized_arg.begin().value();

  switch (schema_arg_type->kind()) {
    case c10::TypeKind::TensorType: {
      TORCH_CHECK(serialized_arg_type == "as_tensor");
      stack.emplace_back();
      dynamic_args.emplace_back(index, DynamicArgType::TensorType, 1);
      break;
    }
    case c10::TypeKind::IntType: {
      TORCH_CHECK(serialized_arg_type == "as_int");
      stack.emplace_back(c10::IValue());
      dynamic_args.emplace_back(index, DynamicArgType::IntType, 1);
      break;
    }
    case c10::TypeKind::SymIntType: {
      TORCH_CHECK(
          serialized_arg_type == "as_int" ||
          serialized_arg_type == "as_sym_int");
      stack.emplace_back(c10::IValue());
      dynamic_args.emplace_back(index, DynamicArgType::IntType, 1);
      break;
    }
    case c10::TypeKind::FloatType: {
      TORCH_CHECK(serialized_arg_type == "as_float");
      stack.emplace_back(serialized_arg_val.get<double>());
      break;
    }
    case c10::TypeKind::BoolType: {
      TORCH_CHECK(serialized_arg_type == "as_bool");
      stack.emplace_back(serialized_arg_val.get<bool>());
      break;
    }
    case c10::TypeKind::NumberType: {
      if (serialized_arg_type == "as_int") {
        // Only int Scalar is treated as dynamic arg for now
        stack.emplace_back();
        dynamic_args.emplace_back(index, DynamicArgType::IntType, 1);
      } else if (serialized_arg_type == "as_float") {
        stack.emplace_back(serialized_arg_val.get<double>());
      } else if (serialized_arg_type == "as_bool") {
        stack.emplace_back(serialized_arg_val.get<bool>());
      } else {
        TORCH_CHECK(
            false,
            "Invalid serialized argument type found for Scalar input: ",
            serialized_arg_type);
      }
      break;
    }
    case c10::TypeKind::StringType: {
      TORCH_CHECK(serialized_arg_type == "as_string");
      stack.emplace_back(serialized_arg_val.get<std::string>());
      break;
    }
    case c10::TypeKind::DeviceObjType: {
      TORCH_CHECK(serialized_arg_type == "as_device");

      std::string device_string = serialized_arg_val["type"].get<std::string>();
      if (serialized_arg_val["index"].is_number()) {
        // The device index is serialized as a number, so it must be
        // converted to a string before being appended.
        device_string +=
            ":" + std::to_string(serialized_arg_val["index"].get<int>());
      }

      c10::Device device(device_string);

      if (device != *device_) {
        VLOG(1) << "ProxyExecutor is using " << *device_ << " for "
                << op_kernel.target_ << " argument #" << index
                << ", which is different from the one serialized in the JSON "
                << "file: " << device << ". Please ensure this is intentional.";
      }

      stack.emplace_back(*device_);
      break;
    }
    case c10::TypeKind::ListType: {
      if (schema_arg_type->isSubtypeOf(at::ListType::ofTensors())) {
        TORCH_CHECK(serialized_arg_type == "as_tensors");
        stack.emplace_back();
        dynamic_args.emplace_back(
            index, DynamicArgType::ListTensorType, serialized_arg_val.size());
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofInts())) {
        TORCH_CHECK(serialized_arg_type == "as_ints");
        dynamic_args.emplace_back(
            index, DynamicArgType::ListIntType, serialized_arg_val.size());
        stack.emplace_back(c10::IValue());
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) {
        TORCH_CHECK(
            serialized_arg_type == "as_ints" ||
            serialized_arg_type == "as_sym_ints");
        dynamic_args.emplace_back(
            index, DynamicArgType::ListIntType, serialized_arg_val.size());
        stack.emplace_back(c10::IValue());
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofFloats())) {
        TORCH_CHECK(serialized_arg_type == "as_floats");
        std::vector<double> ret;
        for (const auto& arg : serialized_arg_val) {
          ret.push_back(arg.get<double>());
        }
        stack.emplace_back(ret);
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofBools())) {
        TORCH_CHECK(serialized_arg_type == "as_bools");
        std::vector<bool> ret;
        for (const auto& arg : serialized_arg_val) {
          ret.push_back(arg.get<bool>());
        }
        stack.emplace_back(ret);
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofNumbers())) {
        if (serialized_arg_type == "as_ints") {
          dynamic_args.emplace_back(
              index, DynamicArgType::ListIntType, serialized_arg_val.size());
          stack.emplace_back(c10::IValue());
        } else if (serialized_arg_type == "as_floats") {
          std::vector<double> ret;
          for (const auto& arg : serialized_arg_val) {
            ret.push_back(arg.get<double>());
          }
          stack.emplace_back(ret);
        } else if (serialized_arg_type == "as_bools") {
          std::vector<bool> ret;
          for (const auto& arg : serialized_arg_val) {
            ret.push_back(arg.get<bool>());
          }
          stack.emplace_back(ret);
        } else {
          TORCH_CHECK(
              false,
              "Invalid serialized argument type found for List[Scalar] ",
              serialized_arg_type);
        }
      } else if (schema_arg_type->isSubtypeOf(
                     at::ListType::ofOptionalTensors())) {
        if (serialized_arg_type == "as_optional_tensors") {
          std::vector<std::string> list_item_types;
          for (const auto& arg : serialized_arg_val) {
            list_item_types.push_back(arg.begin().key());
          }
          stack.emplace_back();
          dynamic_args.emplace_back(
              index,
              DynamicArgType::ListOptionalTensorType,
              serialized_arg_val.size(),
              list_item_types);
        } else if (serialized_arg_type == "as_tensors") {
          stack.emplace_back();
          dynamic_args.emplace_back(
              index, DynamicArgType::ListTensorType, serialized_arg_val.size());
        } else {
          TORCH_CHECK(
              false,
              "Invalid serialized type found for argument of type `Tensor?[]`: ",
              serialized_arg_type);
        }
      } else if (schema_arg_type->isSubtypeOf(at::ListType::ofStrings())) {
        TORCH_CHECK(serialized_arg_type == "as_strings");
        std::vector<std::string> ret;
        for (const auto& arg : serialized_arg_val) {
          ret.push_back(arg.get<std::string>());
        }
        stack.emplace_back(ret);
      } else {
        TORCH_CHECK(false, "NYI: Unsupported list type ", serialized_arg_type);
      }
      break;
    }
    case c10::TypeKind::OptionalType: {
      auto inner_type =
          schema_arg_type->castRaw<at::OptionalType>()->getElementType();

      if (serialized_arg_type == "as_none") {
        stack.emplace_back(c10::nullopt);
        if (inner_type->kind() == c10::TypeKind::TensorType) {
          // Tensor is None
          dynamic_args.emplace_back(index, DynamicArgType::TensorType, 0);
        } else if (
            inner_type->kind() == c10::TypeKind::IntType ||
            inner_type->kind() == c10::TypeKind::SymIntType) {
          // Int or SymInt is None
          dynamic_args.emplace_back(index, DynamicArgType::IntType, 0);
        } else if (
            inner_type->kind() == c10::TypeKind::ListType &&
            inner_type->isSubtypeOf(at::ListType::ofTensors())) {
          // List[Tensor] is None
          dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, 0);
        } else if (
            inner_type->kind() == c10::TypeKind::ListType &&
            inner_type->isSubtypeOf(at::ListType::ofSymInts())) {
          // List[SymInt] is None
          dynamic_args.emplace_back(index, DynamicArgType::ListIntType, 0);
        }
      } else {
        prefill_stack_with_static_arguments(
            index, inner_type, serialized_arg, op_kernel);
      }
      break;
    }
    // TODO: handle the other input types
    default:
      TORCH_CHECK(false, "Unsupported input type ", serialized_arg_type);
  }
}

// Populates op_kernel.stack_, op_kernel.dynamic_args_
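// Each entry of serialized_node["inputs"] is expected to carry the argument
// payload under an "arg" key, e.g. {"name": ..., "arg": {"as_tensor": {...}}}
// (a sketch inferred from the field accesses below, not a full schema).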
void OSSProxyExecutor::get_input_info_from_serialized(
    const std::vector<c10::Argument>& schema_args,
    const nlohmann::json& serialized_node,
    OSSOpKernel& op_kernel) {
  int index = 0;
  for (const auto& named_argument : serialized_node["inputs"]) {
    const auto& arg = named_argument["arg"];
    const auto& schema_arg = schema_args[index];

    prefill_stack_with_static_arguments(
        index++, schema_arg.real_type(), arg, op_kernel);
  }

  // TODO: prefill default values
}

// Populates op_kernel.outputs_
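// Each serialized output is a single-key object, e.g. {"as_tensor": {...}} or
// {"as_tensors": [...]}, matching the checks below.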
void OSSProxyExecutor::get_output_info_from_serialized(
    const std::vector<c10::Argument>& schema_returns,
    const nlohmann::json& serialized_node,
    OSSOpKernel& op_kernel) {
  std::vector<OSSDynamicArg>& outputs = op_kernel.outputs_;

  TORCH_CHECK(
      schema_returns.size() == serialized_node["outputs"].size(),
      "Serialized node doesn't match op's schema outputs.");

  size_t output_index = 0;
  for (const auto& serialized_output : serialized_node["outputs"]) {
    TORCH_CHECK(serialized_output.size() == 1);
    std::string serialized_output_type = serialized_output.begin().key();
    auto& serialized_output_val = serialized_output.begin().value();

    auto& schema_return = schema_returns[output_index];
    at::TypePtr schema_return_type = schema_return.real_type();

    switch (schema_return_type->kind()) {
      case c10::TypeKind::TensorType: {
        TORCH_CHECK(
            serialized_output_type == "as_tensor",
            serialized_node["target"],
            " got serialized_output_type of ",
            serialized_output_type);
        outputs.emplace_back(output_index, DynamicArgType::TensorType, 1);
        break;
      }
      case c10::TypeKind::ListType: {
        if (schema_return_type->isSubtypeOf(at::ListType::ofTensors())) {
          TORCH_CHECK(
              serialized_output_type == "as_tensors",
              serialized_node["target"],
              " got serialized_output_type of ",
              serialized_output_type);
          outputs.emplace_back(
              output_index,
              DynamicArgType::ListTensorType,
              serialized_output_val.size());
        } else {
          TORCH_CHECK(
              false,
              "Unsupported return list type ",
              schema_return_type->repr_str());
        }
        break;
      }
      default: {
        TORCH_CHECK(
            false, "Unsupported return type ", schema_return_type->repr_str());
      }
    }

    output_index++;
  }
}

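// Loads the serialized extern kernel nodes from the JSON file written by AOT
// Inductor and pre-builds an OSSOpKernel (prefilled stack plus dynamic-arg
// metadata) for each node. A rough sketch of the expected layout, inferred
// from the parsing below:
//   {"nodes": [{"node": {"target": ..., "inputs": [...], "outputs": [...]}}]}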
OSSProxyExecutor::OSSProxyExecutor(const std::string& json_path, bool is_cpu) {
  if (is_cpu) {
    device_ = std::make_unique<c10::Device>(c10::DeviceType::CPU);
  } else {
    int device_idx = -1;
    device_ = std::make_unique<c10::Device>(c10::DeviceType::CUDA, device_idx);
  }

  std::string extern_kernel_nodes_serialized;

  std::ifstream json_file(json_path);
  TORCH_CHECK(json_file.is_open());

  // Parse file into a json object
  nlohmann::json json_obj;
  json_file >> json_obj;

  // Access data
  for (auto const& serialized_extern_node : json_obj["nodes"]) {
    auto const& serialized_node = serialized_extern_node["node"];

    const std::string& target = serialized_node["target"];

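    // Split the target into the op name and the overload name at the '.',
    // e.g. a (hypothetical) target "aten::add.Tensor" yields
    // opName = "aten::add" and overloadName = "Tensor".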
    std::string opName;
    std::string overloadName;
    size_t pos = target.find('.');
    if (pos == std::string::npos) {
      opName = target;
      overloadName = "";
    } else {
      // There should be no more periods after the first one
      size_t pos2 = target.find('.', pos + 1);
      TORCH_CHECK(pos2 == std::string::npos);

      opName = target.substr(0, pos);
      overloadName = target.substr(pos + 1);
    }

    c10::OperatorHandle op_handle =
        c10::Dispatcher::singleton().findSchemaOrThrow(
            opName.c_str(), overloadName.c_str());
    const c10::FunctionSchema& schema = op_handle.schema();

    const auto& schema_args = schema.arguments();
    const auto& schema_returns = schema.returns();

    OSSOpKernel op_kernel(target, op_handle);
    get_input_info_from_serialized(schema_args, serialized_node, op_kernel);
    get_output_info_from_serialized(schema_returns, serialized_node, op_kernel);

    op_kernels_.emplace_back(std::move(op_kernel));
  }
}

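// Called once per extern kernel node by the AOTInductor-generated code.
// Dynamic arguments arrive flattened: all int values (including int lists) in
// flatten_int_args, and all input tensors followed by the preallocated output
// tensors in flatten_tensor_args.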
void OSSProxyExecutor::call_function(
    int extern_node_index,
    int num_ints,
    int64_t* flatten_int_args,
    int num_tensors,
    AtenTensorHandle* flatten_tensor_args) {
  TORCH_CHECK(
      extern_node_index < static_cast<int>(op_kernels_.size()),
      "Invalid extern node index");
  OSSOpKernel& op_kernel = op_kernels_[extern_node_index];

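  // Copy the prefilled stack so that filling in dynamic arguments does not
  // mutate the template stored in the op_kernel, which is reused across calls.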
  std::vector<c10::IValue> stack = op_kernel.stack_;
  auto& dynamic_args = op_kernel.dynamic_args_;

  int tensor_id = 0;
  int int_id = 0;
  for (auto& dynamic_arg : dynamic_args) {
    int arg_index = dynamic_arg.arg_index;
    DynamicArgType dynamic_arg_type = dynamic_arg.arg_type;
    int length = dynamic_arg.length;

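    // A length of 0 means the argument was serialized as None (see the
    // OptionalType handling in prefill_stack_with_static_arguments); it
    // consumes no slots in the flattened inputs.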
    if (length == 0) {
      continue;
    }

    switch (dynamic_arg_type) {
      case DynamicArgType::TensorType: {
        at::Tensor* tensor =
            tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
        stack[arg_index] = *tensor;
        break;
      }
      case DynamicArgType::IntType: {
        int64_t val = flatten_int_args[int_id++];
        stack[arg_index] = val;
        break;
      }
      case DynamicArgType::ListTensorType: {
        std::vector<at::Tensor> tensor_list;
        for (int j = 0; j < length; j++) {
          at::Tensor* tensor =
              tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
          tensor_list.push_back(*tensor);
        }
        stack[arg_index] = tensor_list;
        break;
      }
      case DynamicArgType::ListOptionalTensorType: {
        std::vector<std::optional<at::Tensor>> optional_tensor_list;
        auto& list_item_types = dynamic_arg.list_item_types;
        TORCH_CHECK(
            list_item_types.has_value(),
            "Could not find list of item types for optional tensor list input");

        for (const std::string& item_type : list_item_types.value()) {
          if (item_type == "as_tensor") {
            at::Tensor* tensor = tensor_handle_to_tensor_pointer(
                flatten_tensor_args[tensor_id++]);
            optional_tensor_list.emplace_back(*tensor);
          } else if (item_type == "as_none") {
            optional_tensor_list.emplace_back(c10::nullopt);
          }
        }
        stack[arg_index] = optional_tensor_list;
        break;
      }
      case DynamicArgType::ListIntType: {
        std::vector<int64_t> vals;
        for (int j = 0; j < length; j++) {
          vals.push_back(flatten_int_args[int_id++]);
        }
        stack[arg_index] = vals;
        break;
      }
      default:
        TORCH_CHECK(false, "Unsupported dynamic arg type: ", dynamic_arg_type);
    }
  }

  int num_output_tensors = op_kernel.num_output_tensors();
  TORCH_CHECK(
      tensor_id == num_tensors - num_output_tensors,
      "Mismatch between tensors consumed and the number of input tensors, got tensor_id = ",
      tensor_id,
      ", expected num = ",
      num_tensors - num_output_tensors);
  TORCH_CHECK(
      int_id == num_ints,
      "Mismatch between ints consumed and num_ints, got int_id = ",
      int_id,
      ", num_ints = ",
      num_ints);

  // Call the op with the prepared stack.
  const c10::OperatorHandle& op = op_kernel.op_handle_;
  op.callBoxed(stack);

  const c10::FunctionSchema& schema = op.schema();
  const auto& schema_returns = schema.returns();

  TORCH_CHECK(op_kernel.outputs_.size() == stack.size());
  // TODO: what about optional outputs? This assert may not hold
  TORCH_CHECK(stack.size() == schema_returns.size());

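  // Copy the op's tensor outputs back into the preallocated output tensors,
  // which occupy the tail of flatten_tensor_args.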
  int index = 0;
  for (const auto& schema_return : schema_returns) {
    if (schema_return.type()->kind() == c10::TypeKind::TensorType) {
      at::Tensor* tensor =
          tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
      *tensor = stack[index++].toTensor();
    } else if (
        schema_return.type()->kind() == c10::TypeKind::ListType &&
        schema_return.type()->isSubtypeOf(at::ListType::ofTensors())) {
      auto tensors = stack[index++].toTensorList();
      for (size_t i = 0; i < tensors.size(); ++i) {
        at::Tensor* tensor =
            tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
        *tensor = tensors[i];
      }
    } else {
      TORCH_CHECK(
          false,
          "NYI: Unsupported return type for schema: ",
          schema_return.type()->repr_str());
    }
  }

  TORCH_CHECK(
      tensor_id == num_tensors,
      "Mismatch between tensors consumed and num_tensors, got tensor_id = ",
      tensor_id,
      ", expected num = ",
      num_tensors);
}

} // namespace torch::aot_inductor