#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <torch/csrc/inductor/aoti_eager/kernel_holder.h>

#include <ATen/ATen.h>

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include <torch/csrc/jit/frontend/function_schema_parser.h>

#include <ATen/core/jit_type.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>

namespace torch::inductor {

namespace {
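// Helpers that flatten the IValues on the dispatcher stack into the flat
// std::vector<at::Tensor> consumed by the AOTI model container runner.

// Append a plain Tensor argument to the flattened input list.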
inline void unpack_tensor_ivalue(
    const c10::IValue& ivalue,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  inputs.push_back(ivalue.toTensor());
}
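// Append an optional Tensor argument; a std::nullopt value is skipped rather
// than materialized as an undefined tensor.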
inline void unpack_optional_tensor_ivalue(
    const c10::IValue& ivalue,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  auto ivalue_opt_tensor = ivalue.toOptional<at::Tensor>();
  if (ivalue_opt_tensor.has_value()) {
    inputs.push_back(ivalue_opt_tensor.value());
  }
}
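// Append every Tensor in a TensorList argument, preserving list order.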
inline void unpack_tensor_list_ivalue(
    const c10::IValue& ivalue,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  for (const auto& item : ivalue.toListRef()) {
    inputs.push_back(item.toTensor());
  }
}
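// Append the present elements of an optional-Tensor list (Tensor?[]); empty
// optionals are skipped.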
inline void unpack_optional_tensor_list_ivalue(
    const c10::IValue& ivalue,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  for (const auto& item : ivalue.toListRef()) {
    unpack_optional_tensor_ivalue(item, device, inputs);
  }
}
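// Flatten all tensor-like arguments on the stack (Tensor, TensorList,
// Tensor?[] and Tensor?) into a single vector, in argument order; non-tensor
// arguments such as scalars are ignored here. For example, for a call like
// torch.ops.aten.add.Tensor(x, y, alpha=2) this yields {x, y}, while alpha is
// captured separately by unpack_input_parameters below.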
std::vector<at::Tensor> unpack_tensors(
    const std::vector<c10::Argument>& arguments,
    const torch::jit::Stack& stack,
    const c10::Device& device) {
  std::vector<at::Tensor> inputs;
  for (size_t idx = 0; idx < stack.size(); idx++) {
    const auto& ivalue = stack[idx];
    const auto& ivalue_arg = arguments[idx];
    if (ivalue.isTensor()) {
      unpack_tensor_ivalue(ivalue, device, inputs);
    } else if (ivalue.isTensorList()) {
      unpack_tensor_list_ivalue(ivalue, device, inputs);
    } else if (ivalue.isOptionalTensorList()) {
      unpack_optional_tensor_list_ivalue(ivalue, device, inputs);
    } else if (
        *ivalue_arg.real_type() ==
        *c10::getTypePtr<std::optional<at::Tensor>>()) {
      // ivalue is std::optional<at::Tensor>
      unpack_optional_tensor_ivalue(ivalue, device, inputs);
    }
  }
  return inputs;
}
// Return true if the ivalue equals the argument's declared default value.
bool is_default_value(
    const c10::Argument& argument,
    const c10::IValue& ivalue) {
  if (!argument.default_value().has_value()) {
    return false;
  }
  const auto& default_ivalue = *argument.default_value();
  return default_ivalue == ivalue;
}
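// Translate the arguments on the stack into the ParameterMetadata list used
// as the AOTI kernel cache key. Arguments that still hold their schema
// default are skipped, and arg_order counts only the arguments that are kept.
// For example, torch.ops.aten.add.Tensor(x, y) records two Tensor entries
// (orders 0 and 1) and omits alpha=1, whereas passing alpha=2 adds a Scalar
// entry with order 2.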
std::vector<ParameterMetadata> unpack_input_parameters(
    const std::vector<c10::Argument>& arguments,
    const torch::jit::Stack& stack) {
  std::vector<ParameterMetadata> inputs_metadata;
  // Positional order of the kept arguments; defaulted parameters are skipped
  // and do not advance the order.
  int64_t arg_order = 0;
  for (size_t idx = 0; idx < stack.size(); idx++) {
    // A parameter is left out of the cache key if it still holds its default
    // value:
    // - produce_aoti_kernel_lib uses parseIValuesToPyArgsKwargs to build the
    //   args and kwargs for the Python-side compiler.
    // - parseIValuesToPyArgsKwargs skips a parameter whose value equals its
    //   default.
    if (is_default_value(arguments[idx], stack[idx])) {
      continue;
    }

    if (stack[idx].isScalar()) {
      // Besides explicit c10::Scalar arguments, plain floating-point and
      // integer values are also represented as Scalar here.
      inputs_metadata.emplace_back(stack[idx].toScalar(), arg_order);
    } else if (stack[idx].isTensorList()) {
      // tensor list
      inputs_metadata.emplace_back(stack[idx].toTensorList().vec(), arg_order);
    } else if (stack[idx].isOptionalTensorList()) {
      // optional tensor list: std::vector<std::optional<at::Tensor>>
      std::vector<at::Tensor> tensor_list;
      for (const auto& item : stack[idx].toListRef()) {
        if (item.toOptional<at::Tensor>().has_value()) {
          tensor_list.push_back(item.toOptional<at::Tensor>().value());
        }
      }
      inputs_metadata.emplace_back(tensor_list, arg_order);
    } else if (
        *arguments[idx].real_type() ==
        *c10::getTypePtr<std::optional<at::Tensor>>()) {
      // optional tensor
      if (stack[idx].toOptional<at::Tensor>().has_value()) {
        inputs_metadata.emplace_back(
            stack[idx].toOptional<at::Tensor>().value(), arg_order);
      }
    } else if (stack[idx].isTensor()) {
      inputs_metadata.emplace_back(stack[idx].toTensor(), arg_order);
    } else if (stack[idx].isString()) {
      inputs_metadata.emplace_back(stack[idx].toStringRef(), arg_order);
    } else if (stack[idx].isBool()) {
      inputs_metadata.emplace_back(c10::Scalar(stack[idx].toBool()), arg_order);
    } else if (stack[idx].isDevice()) {
      inputs_metadata.emplace_back(stack[idx].toDevice(), arg_order);
    } else {
      TORCH_CHECK_NOT_IMPLEMENTED(
          false,
          "Not implemented for operations that contain a parameter which is ",
          "not one of the following types: at::Tensor, at::TensorList, ",
          "std::optional<at::Tensor>, std::vector<std::optional<at::Tensor>> ",
          "and c10::Scalar. The input type is ",
          stack[idx].type()->str());
    }
    arg_order++;
  }

  return inputs_metadata;
}

} // namespace
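// Registration-time setup: record which dispatch key / device this holder
// serves, verify that an AOTI model runner exists for the device (CPU, CUDA,
// or a runner registered via getAOTIModelRunnerRegistry), and warm the
// in-memory kernel cache from the persistent Python-side cache.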
AOTIPythonKernelHolder::AOTIPythonKernelHolder(
    c10::DispatchKey dispatch_key,
    c10::string_view ns,
    c10::string_view op_name_with_overload)
    : dispatch_key_(dispatch_key),
      ns_(std::string(ns)),
      op_name_with_overload_(std::string(op_name_with_overload)),
      device_(c10::dispatchKeyToDeviceType(dispatch_key_), 0),
      pyinterpreter_(getPyInterpreter()) {
  auto device_name = c10::DeviceTypeName(device_.type());
  auto registered_aoti_runner = getAOTIModelRunnerRegistry();
  TORCH_CHECK(
      device_.type() == c10::DeviceType::CUDA ||
          device_.type() == c10::DeviceType::CPU ||
          registered_aoti_runner.find(device_name) !=
              registered_aoti_runner.end(),
      "AOTI for eager does not support ",
      c10::DeviceTypeName(device_.type()),
      " yet.");

  init_aoti_kernel_cache();
}
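// Dispatcher entry point: try to serve the call from the AOTI kernel cache;
// fall back to compiling a new kernel library on a miss.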
void AOTIPythonKernelHolder::operator()(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack) {
  AOTIKernelMetadata aoti_kernel_metadata;
  if (cache_lookup(op, keyset, stack, aoti_kernel_metadata)) {
    cache_hit(aoti_kernel_metadata, op, keyset, stack);
  } else {
    cache_miss(op, keyset, stack);
  }
}
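// Build the ParameterMetadata for the current call and scan the cached
// kernels for one whose recorded metadata checks out against it. Only
// operations that return a single Tensor are supported.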
bool AOTIPythonKernelHolder::cache_lookup(
    const c10::OperatorHandle& op,
    const c10::DispatchKeySet& keyset,
    const torch::jit::Stack* stack,
    AOTIKernelMetadata& aoti_kernel_metadata) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      op.schema().returns().size() == 1,
      "Not implemented for operations that return either multiple values or no value.");
  TORCH_CHECK_NOT_IMPLEMENTED(
      op.schema().returns()[0].type()->isSubtypeOf(c10::TensorType::get()),
      "Not implemented for operations that return a non-Tensor value.");

  auto inputs_metadata =
      unpack_input_parameters(op.schema().arguments(), *stack);
  for (const auto& aoti_kernel_cache : aoti_kernel_cache_) {
    if (aoti_kernel_cache.check(inputs_metadata)) {
      aoti_kernel_metadata = aoti_kernel_cache;
      return true;
    }
  }

  return false;
}
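// Run the cached AOTI kernel: flatten the tensor inputs, pop the original
// arguments off the stack, and push the kernel's outputs in their place.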
void AOTIPythonKernelHolder::cache_hit(
    const AOTIKernelMetadata& aoti_kernel_metadata,
    const c10::OperatorHandle& op,
    const c10::DispatchKeySet& keyset,
    torch::jit::Stack* stack) {
  auto inputs = unpack_tensors(op.schema().arguments(), *stack, device_);
  torch::jit::drop(*stack, op.schema().arguments().size());

  auto outputs = aoti_kernel_metadata.kernel_runner_->run(inputs);
  for (auto& output : outputs) {
    stack->emplace_back(output);
  }
}
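// Populate aoti_kernel_cache_ from the persistent eager cache exposed by
// torch._inductor.aoti_eager.load_aoti_eager_cache. Each entry supplies a
// kernel_path (the compiled library) plus a meta_info list describing the
// parameters the kernel was compiled for; that metadata is converted into
// ParameterMetadata/TensorMetadata so it can be checked against future calls.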
void AOTIPythonKernelHolder::init_aoti_kernel_cache() {
  if (device_.type() == c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES) {
    return;
  }

  py::gil_scoped_acquire gil;

  py::handle load_aoti_eager_cache_function =
      py::module::import("torch._inductor.aoti_eager")
          .attr("load_aoti_eager_cache");
  TORCH_INTERNAL_ASSERT(
      load_aoti_eager_cache_function.ptr() != nullptr,
      "Failed to import torch._inductor.aoti_eager.load_aoti_eager_cache");

  auto result = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
      load_aoti_eager_cache_function.ptr(),
      py::str(ns_).ptr(),
      py::str(op_name_with_overload_).ptr(),
      py::str(c10::DeviceTypeName(device_.type(), true)).ptr(),
      nullptr));
  TORCH_INTERNAL_ASSERT(
      result.ptr() != nullptr && result.ptr() != Py_None,
      "Failed to load AOTI kernel. Operator Name is ",
      op_name_with_overload_);
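  // Convert one Python-side tensor metadata dict (dtype, device, sizes,
  // strides, dispatch key set, ...) into a TensorMetadata and build the
  // dynamo guard used later to validate cache hits.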
  auto build_tensor_metadata = [](const py::dict& metadata) -> TensorMetadata {
    // Access the fields of each metadata dict
    auto is_dynamic = metadata["is_dynamic"].cast<bool>();
    auto device_type = metadata["device_type"].cast<std::string>();
    auto device_index = metadata["device_index"].cast<int8_t>();
    auto data_type_obj = metadata["dtype"].cast<py::object>();
    TORCH_INTERNAL_ASSERT(THPDtype_Check(data_type_obj.ptr()));
    auto data_type =
        reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
    auto sizes = metadata["sizes"].cast<std::vector<int64_t>>();
    auto strides = metadata["strides"].cast<std::vector<int64_t>>();
    auto requires_grad = metadata["requires_grad"].cast<bool>();
    auto dispatch_key_set_raw_repr =
        metadata["dispatch_key_set"].cast<uint64_t>();
    auto dispatch_key_set = c10::DispatchKeySet(
        c10::DispatchKeySet::RAW, dispatch_key_set_raw_repr);
    auto device = c10::Device(device_type);
    device.set_index(device_index);

    auto tensor_metadata = TensorMetadata(
        is_dynamic,
        data_type,
        device,
        dispatch_key_set,
        sizes,
        strides,
        requires_grad);

    // Build guard for tensor check
    torch::dynamo::LocalState state;
    state.overrideDispatchKeySet(dispatch_key_set);
    tensor_metadata.build_guard(state);

    return tensor_metadata;
  };
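  // Each cache entry pairs a compiled kernel library with the parameter
  // metadata recorded when that kernel was produced.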
  TORCH_INTERNAL_ASSERT(py::isinstance<py::list>(result));
  auto kernel_info_list = result.cast<py::list>();
  for (auto kernel_info : kernel_info_list) {
    TORCH_INTERNAL_ASSERT(py::isinstance<py::dict>(kernel_info));
    auto item_dict = kernel_info.cast<py::dict>();

    // Access the kernel_path field
    auto kernel_path = item_dict["kernel_path"].cast<std::string>();

    // Access the meta_info list
    auto inputs_metadata = item_dict["meta_info"].cast<py::list>();

    std::vector<ParameterMetadata> parameter_metadata_list;
    // Loop over the meta_info list
    for (auto item_metadata : inputs_metadata) {
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(py::isinstance<py::dict>(item_metadata));
      auto metadata = item_metadata.cast<py::dict>();
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(metadata.contains("arg_order"));
      uint64_t arg_idx = metadata["arg_order"].cast<uint64_t>();
      bool is_scalar = metadata.contains("scalar_value");
      bool is_tensor_list = metadata.contains("tensor_list");
      bool is_string = metadata.contains("string_value");
      bool is_device = metadata.contains("device_type_value");
      bool is_dtype = metadata.contains("dtype_value");
      bool is_layout = metadata.contains("layout_value");
      if (is_tensor_list) {
        // Tensor list
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
            py::isinstance<py::list>(metadata["tensor_list"]));
        auto tensor_list = metadata["tensor_list"].cast<py::list>();
        std::vector<TensorMetadata> tensor_list_metadata;
        for (auto item_tensor : tensor_list) {
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
              py::isinstance<py::dict>(item_tensor));
          auto tensor_dict = item_tensor.cast<py::dict>();
          tensor_list_metadata.push_back(build_tensor_metadata(tensor_dict));
        }
        parameter_metadata_list.emplace_back(tensor_list_metadata, arg_idx);
      } else if (is_scalar) {
        // Scalar; the recorded dtype tells us how to read scalar_value back.
        auto dtype_obj = metadata["dtype"].cast<py::object>();
        TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj.ptr()));
        auto dtype_value =
            reinterpret_cast<THPDtype*>(dtype_obj.ptr())->scalar_type;
        c10::Scalar scalar;
        if (c10::isFloatingType(dtype_value)) {
          scalar = metadata["scalar_value"].cast<double>();
        } else if (c10::isIntegralType(dtype_value, false)) {
          scalar = metadata["scalar_value"].cast<int64_t>();
        } else if (dtype_value == c10::kBool) {
          scalar = metadata["scalar_value"].cast<bool>();
        } else {
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Not implemented for operations that contain a scalar parameter of dtype ",
              dtype_value);
        }
        parameter_metadata_list.emplace_back(c10::Scalar(scalar), arg_idx);
      } else if (is_string) {
        // String
        auto str_value = metadata["string_value"].cast<std::string>();
        parameter_metadata_list.emplace_back(str_value, arg_idx);
      } else if (is_dtype) {
        // Dtype
        auto dtype_value_obj = metadata["dtype_value"].cast<py::object>();
        TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_value_obj.ptr()));
        auto dtype_value =
            reinterpret_cast<THPDtype*>(dtype_value_obj.ptr())->scalar_type;
        parameter_metadata_list.emplace_back(
            c10::Scalar(static_cast<int>(dtype_value)), arg_idx);
      } else if (is_device) {
        // Device
        auto device_type_value =
            metadata["device_type_value"].cast<std::string>();
        auto device = c10::Device(device_type_value);
        if (metadata["device_index_value"].ptr() != Py_None) {
          auto device_index_value =
              metadata["device_index_value"].cast<c10::DeviceIndex>();
          device.set_index(device_index_value);
        }
        parameter_metadata_list.emplace_back(device, arg_idx);
      } else if (is_layout) {
        // Layout
        auto layout_value_obj = metadata["layout_value"].cast<py::object>();
        TORCH_INTERNAL_ASSERT(THPLayout_Check(layout_value_obj.ptr()));
        auto layout_value =
            reinterpret_cast<THPLayout*>(layout_value_obj.ptr())->layout;
        parameter_metadata_list.emplace_back(
            c10::Scalar(static_cast<int>(layout_value)), arg_idx);
      } else {
        // Tensor
        auto tensor_metadata = build_tensor_metadata(metadata);
        parameter_metadata_list.emplace_back(tensor_metadata, arg_idx);
      }
    }
    AOTIKernelMetadata aoti_kernel_metadata;
    aoti_kernel_metadata.parameter_metadata_list_ =
        std::move(parameter_metadata_list);
    aoti_kernel_metadata.kernel_runner_ = load_aoti_model_runner(kernel_path);
    aoti_kernel_cache_.push_back(aoti_kernel_metadata);
  }
}
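// Create a model container runner for the compiled library at so_path:
// AOTIModelContainerRunnerCuda / ...Cpu for the built-in device types, or
// the runner factory the device registered via getAOTIModelRunnerRegistry().
// Returns nullptr when a CUDA runner is requested but the build lacks
// USE_CUDA.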
std::shared_ptr<AOTIModelContainerRunner> AOTIPythonKernelHolder::
    load_aoti_model_runner(const std::string& so_path) {
  auto device_name = c10::DeviceTypeName(device_.type());
  auto registered_aoti_runner = getAOTIModelRunnerRegistry();
  TORCH_CHECK(
      device_.type() == c10::DeviceType::CUDA ||
          device_.type() == c10::DeviceType::CPU ||
          registered_aoti_runner.find(device_name) !=
              registered_aoti_runner.end(),
      "AOTI for eager does not support ",
      c10::DeviceTypeName(device_.type()),
      " yet.");
  if (device_.type() == c10::DeviceType::CUDA) {
#ifdef USE_CUDA
    return std::make_shared<AOTIModelContainerRunnerCuda>(so_path);
#else
    return nullptr;
#endif
  } else if (device_.type() == c10::DeviceType::CPU) {
    return std::make_shared<AOTIModelContainerRunnerCpu>(so_path);
  } else {
    auto aoti_model_runner_fn = registered_aoti_runner[device_name];
    return aoti_model_runner_fn(so_path, 1, device_name, "");
  }
}
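// Slow path: produce an AOTI kernel library for this call via the Python
// side, load it, run it on the flattened tensor inputs, and replace the
// arguments on the stack with the outputs.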
void AOTIPythonKernelHolder::cache_miss(
    const c10::OperatorHandle& op,
    const c10::DispatchKeySet& keyset,
    torch::jit::Stack* stack) {
  auto kernel_lib_path = produce_aoti_kernel_lib(op, keyset, stack);
  auto kernel = load_aoti_model_runner(kernel_lib_path);
  TORCH_INTERNAL_ASSERT(
      kernel != nullptr,
      "Unsupported device: ",
      c10::DeviceTypeName(device_.type()));
  auto inputs = unpack_tensors(op.schema().arguments(), *stack, device_);
  auto outputs = kernel->run(inputs);
  torch::jit::drop(*stack, op.schema().arguments().size());
  // TODO: Look up the operation's output type and convert the outputs
  // accordingly before pushing them.
  for (auto& output : outputs) {
    torch::jit::push(*stack, std::move(output));
  }
}
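// Ask torch._inductor.aoti_eager.aoti_compile_with_persistent_cache (which,
// per its name, persists its results) to AOT-compile this operator for the
// current arguments and return the path to the produced kernel library.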
std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib(
    const c10::OperatorHandle& op,
    const c10::DispatchKeySet& keyset,
    const torch::jit::Stack* stack) {
  auto arguments = torch::jit::last(*stack, op.schema().arguments().size());

  const auto& schema = op.schema();
  const auto& qualified_name = op.operator_name().name;
  const auto& overload_name =
      schema.overload_name().empty() ? "default" : schema.overload_name();
  auto pos = qualified_name.find("::");
  TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
  std::string ns_str(
      qualified_name.begin(),
      qualified_name.begin() + static_cast<ptrdiff_t>(pos));
  std::string func_name(
      qualified_name.begin() + static_cast<ptrdiff_t>(pos + strlen("::")),
      qualified_name.end());

  py::gil_scoped_acquire gil;
  py::handle op_py_func = op.getPythonOp(pyinterpreter_, [&]() -> PyObject* {
    py::handle torch_api_function = py::module::import("torch")
                                        .attr("ops")
                                        .attr(ns_str.c_str())
                                        .attr(func_name.c_str());
    if (overload_name.empty()) {
      return torch_api_function.attr("default").ptr();
    } else {
      return torch_api_function.attr(overload_name.c_str()).ptr();
    }
  });

  TORCH_INTERNAL_ASSERT(
      op_py_func.ptr() != nullptr && op_py_func.ptr() != Py_None,
      "Failed to get python operation. Operator Name is ",
      op.operator_name().name,
      ", Overload Name is ",
      overload_name);
  py::handle aot_compile_function =
      py::module::import("torch._inductor.aoti_eager")
          .attr("aoti_compile_with_persistent_cache");
  TORCH_INTERNAL_ASSERT(
      aot_compile_function.ptr() != nullptr &&
          aot_compile_function.ptr() != Py_None,
      "Failed to import torch._inductor.aoti_eager.aoti_compile_with_persistent_cache");

  // Pass the python operation to the AOT Inductor to generate the kernel
  // library.
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments.vec());
  auto result = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
      aot_compile_function.ptr(),
      py::str(ns_str).ptr(),
      py::str(op_name_with_overload_).ptr(),
      py::str(c10::DeviceTypeName(device_.type(), true)).ptr(),
      py::bool_(false).ptr(),
      op_py_func.ptr(),
      args_kwargs.first.ptr(),
      args_kwargs.second.ptr(),
      nullptr));
  TORCH_INTERNAL_ASSERT(result.ptr() != nullptr && result.ptr() != Py_None);

  auto kernel_lib_path = py::cast<std::string>(result);
  TORCH_CHECK(
      !kernel_lib_path.empty(),
      "Failed to produce kernel library using AOTI for ",
      c10::DeviceTypeName(device_.type()),
      ". Operator Name is ",
      op.operator_name().name,
      ", Overload Name is ",
      op.schema().overload_name());

  return kernel_lib_path;
}

} // namespace torch::inductor
#endif