// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/pybind_utils.h
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <ATen/core/stack.h>
#include <pybind11/complex.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/QScheme.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_custom_class.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/six.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#endif

#include <ATen/core/function_schema.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <optional>

#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// The visibility attribute is to avoid a warning about storing a field in the
// struct that has a different visibility (from pybind) than the struct.
#ifdef _WIN32
#define VISIBILITY_HIDDEN
#else
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
#endif

namespace torch::jit {

using ResolutionCallback = std::function<py::object(std::string)>;

void clear_registered_instances(void* ptr);

TORCH_PYTHON_API IValue toIValue(
    py::handle obj,
    const TypePtr& type,
    std::optional<int32_t> N = std::nullopt);

TORCH_PYTHON_API py::object toPyObject(IValue ivalue);
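
// Illustrative round-trip sketch (not part of the original header): toIValue
// converts a Python object into an IValue of the requested JIT type, and
// toPyObject converts it back. Both assume the caller holds the GIL.
//
//   py::gil_scoped_acquire guard;
//   IValue iv = toIValue(py::int_(3), IntType::get());
//   py::object roundtrip = toPyObject(std::move(iv));  // a Python int again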

// Hack to overload the behavior of toIValue to accept Python
// numbers in places where a Tensor is expected
// See also torch::should_allow_numbers_as_tensors
class ToIValueAllowNumbersAsTensors {
  bool old_;

 public:
  ToIValueAllowNumbersAsTensors(bool enable);
  ~ToIValueAllowNumbersAsTensors();
};
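
// Minimal usage sketch (illustrative only): the class is an RAII guard, so
// the relaxed conversion behavior applies for the lifetime of the guard.
//
//   {
//     ToIValueAllowNumbersAsTensors allow_numbers(/*enable=*/true);
//     IValue as_tensor = toIValue(py::float_(1.5), TensorType::get());
//   }  // previous behavior restored here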

// Wrap a Python function to guard its deref.
// NB: VISIBILITY_HIDDEN is needed to silence the compiler error:
// 'torch::jit::PythonFunctionGuard' declared with greater visibility than the
// type of its field 'torch::jit::PythonFunctionGuard::func_'
struct VISIBILITY_HIDDEN PythonFunctionGuard {
  explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {}

  ~PythonFunctionGuard() {
    pybind11::gil_scoped_acquire ag;
    func_.dec_ref();
    // Explicitly set the PyObject* to nullptr to prevent py::object's dtor
    // from decref'ing the PyObject again.
    // See Note [Destructing py::object] in python_ivalue.h
    func_.ptr() = nullptr;
  }

  py::function func_;
};

// The PythonFutureWrapper for ivalue::Future
//
// NB: VISIBILITY_HIDDEN is for silencing the compile error,
// "error: 'torch::jit::PythonFutureWrapper' declared with greater visibility
// than the type of its field 'torch::jit::PythonFutureWrapper::unwrap_func'
// [-Werror=attributes]"
//
// NB: inherit from enable_shared_from_this because then(py::function) needs to
//     get a shared_ptr from this pointer.
struct VISIBILITY_HIDDEN PythonFutureWrapper
    : std::enable_shared_from_this<PythonFutureWrapper> {
  using UnwrapFunc = std::function<void(py::object)>;

  explicit PythonFutureWrapper(
      c10::intrusive_ptr<c10::ivalue::Future> fut,
      std::optional<UnwrapFunc> unwrap_func = std::nullopt)
      : fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {}

  explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete;
  PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;

  bool done() {
    return fut->completed();
  }

  py::object value() {
    // acquiring GIL as toPyObject creates new py::object
    // without grabbing the GIL.
    py::gil_scoped_acquire acquire;
    py::object py_obj = toPyObject(fut->value());
    // unwrap_func is a general compositional function that takes in a
    // py::object and executes some Python function. It is currently mostly
    // used to throw Python exceptions.
    if (unwrap_func) {
      (*unwrap_func)(py_obj);
    }
    return py_obj;
  }

  py::object wait() {
    fut->wait();
    if (jit::tracer::isTracing()) {
      auto graph = jit::tracer::getTracingState()->graph;

      Value* fut_val = jit::tracer::getValueTrace(fut);
      auto output = graph->insert(aten::wait, {fut_val});
      jit::tracer::setValueTrace(fut->value(), output);
    }
    return value();
  }

  // The py::function cb arg must take a std::shared_ptr<PythonFutureWrapper>
  // (i.e., torch._C.Future) as the only argument. If the type mismatches, an
  // error will be thrown when waiting for the value of this returned Future.
  std::shared_ptr<PythonFutureWrapper> then(py::function cb) {
    // We need an additional layer of wrapping here to guard the destruction
    // of the py::function object: the Future owns a reference to the
    // py::function in its callback vector, but the Future does not acquire
    // the GIL on destruction.
    auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));

    return std::make_shared<jit::PythonFutureWrapper>(fut->then(
        // Capture a copy of the ivalue::Future instead of the `this` pointer
        // because the PythonFutureWrapper object could have been deleted
        // when the callbacks are fired. For example, RPC only captures the
        // ivalue::Future instead of PythonFutureWrapper in JitFuture's
        // callback functions. Hence, if user code does not hold a reference to
        // this PythonFutureWrapper object, there is no guarantee that the
        // PythonFutureWrapper is still valid when running the callback.
        [pyFut(this->getPtr()),
         pf(std::move(pf))](c10::ivalue::Future& /* unused */) -> IValue {
          try {
            pybind11::gil_scoped_acquire ag;
            return toIValue(pf->func_(pyFut), PyObjectType::get());
          } catch (py::error_already_set& e) {
            auto err = std::runtime_error(c10::str(
                "Got the following error when running the callback: ",
                e.what()));
            {
              pybind11::gil_scoped_acquire ag;
              // Release ownership on py::objects and also restore the Python
              // Error Indicator.
              e.restore();
              // Clear the Python Error Indicator, as we have recorded the
              // exception in the response message.
              PyErr_Clear();
            }

            throw err;
          }
        },
        PyObjectType::get()));
  }

  void add_done_callback(py::function cb) {
    auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));
    // NOLINTNEXTLINE(modernize-avoid-bind)
    fut->addCallback(std::bind(
        [pyFut(this->getPtr())](
            const std::shared_ptr<PythonFunctionGuard>& pf) {
          try {
            pybind11::gil_scoped_acquire ag;
            pf->func_(pyFut);
          } catch (py::error_already_set& e) {
            {
              pybind11::gil_scoped_acquire ag;
              // Release ownership on py::objects and also restore the Python
              // Error Indicator.
              e.restore();
              // Clear the Python Error Indicator, as we have recorded the
              // exception in the response message.
              PyErr_Clear();
            }
            // Log and ignore exceptions raised through the callback
            LOG(ERROR) << "Got the following error when running the callback: "
                       << e.what();

          } catch (const std::exception& e) {
            // Log and ignore exceptions raised through the callback
            LOG(ERROR) << "Got the following error when running the callback: "
                       << e.what();
          }
        },
        std::move(pf)));
  }

  void markCompleted(const py::object& pyValue) {
    DCHECK(PyGILState_Check());
    IValue value = toIValue(pyValue, PyObjectType::get());

    py::gil_scoped_release release;
    fut->markCompleted(std::move(value));
  }

  c10::intrusive_ptr<c10::ivalue::Future> fut;
  // unwrap_func works like a callback for the value returned by
  // PythonFutureWrapper::wait().
  std::optional<UnwrapFunc> unwrap_func;

 private:
  std::shared_ptr<PythonFutureWrapper> getPtr() {
    return shared_from_this();
  }
};
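
// Usage sketch (illustrative only, not from the original header): wrapping an
// ivalue::Future so Python-facing code can poll and read it.
//
//   auto fut = c10::make_intrusive<c10::ivalue::Future>(IntType::get());
//   auto py_fut = std::make_shared<PythonFutureWrapper>(fut);
//   fut->markCompleted(IValue(42));
//   {
//     py::gil_scoped_acquire guard;
//     bool ready = py_fut->done();           // true
//     py::object result = py_fut->value();   // Python int 42
//   }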

// The PythonAwaitWrapper for ivalue::Await
//
// Expresses delayed function execution with lazy semantics: in eager mode,
// Await[W] can be used as W. When an attribute of type W is requested,
// Await[W] returns the attribute of W, transparently calling wait()
// beforehand. There are no lazy semantics in script mode; an explicit
// wait(Await[W]) -> W must be called to convert to type W.
//
// The Await object takes shared ownership of the specified function and its
// arguments. After the first call to wait() it owns the result. Deliberately,
// no type inference is done in eager mode.
struct VISIBILITY_HIDDEN PythonAwaitWrapper
    : std::enable_shared_from_this<PythonAwaitWrapper> {
  explicit PythonAwaitWrapper(c10::intrusive_ptr<c10::ivalue::Await> aw)
      : aw_(std::move(aw)) {}
  explicit PythonAwaitWrapper(py::handle input) {
    args_ = py::tuple(1u);
    args_[0] = input;
    auto type = PyObjectType::get();
    aw_ = c10::make_intrusive<c10::ivalue::Await>(type);
    aw_->markCompleted(toIValue(input, type));
  }

  explicit PythonAwaitWrapper(py::function pf, py::tuple args)
      : args_(std::move(args)) {
    pyfg_ = std::make_shared<torch::jit::PythonFunctionGuard>(std::move(pf));

    std::function<IValue()> f = [fg(pyfg_), &args(args_)]() {
      pybind11::gil_scoped_acquire ag;
      return toIValue(fg->func_(*args), PyObjectType::get());
    };
    aw_ = c10::make_intrusive<c10::ivalue::Await>(
        PyObjectType::get(), std::move(f));
  }

  explicit PythonAwaitWrapper(const PythonAwaitWrapper&) = delete;
  PythonAwaitWrapper& operator=(const PythonAwaitWrapper&) = delete;

  py::object wait() {
    py::gil_scoped_acquire acquire;
    return toPyObject(aw_->wait());
  }

  // Nowait semantics cover the trivial case where the Await is constructed
  // directly from the result.
  bool is_nowait() {
    return pyfg_ == nullptr;
  }

  const py::function fn() {
    TORCH_CHECK(
        pyfg_, "Await constructed as awaitable_nowait does not have fn");
    return pyfg_->func_;
  }

  const py::tuple args() {
    return args_;
  }

  TypePtr type() {
    return aw_->type();
  }

  c10::intrusive_ptr<c10::ivalue::Await> aw_;
  std::shared_ptr<torch::jit::PythonFunctionGuard> pyfg_;
  py::tuple args_;

 private:
  std::shared_ptr<PythonAwaitWrapper> getPtr() {
    return shared_from_this();
  }
};
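
// Usage sketch (illustrative only, not from the original header): the
// "nowait" form wraps an already-available value, while the function form
// defers running the Python callable until wait() is called.
//
//   py::gil_scoped_acquire guard;
//   PythonAwaitWrapper ready(py::int_(7));   // nowait: value already known
//   bool trivial = ready.is_nowait();        // true, no callable to run
//   py::object seven = ready.wait();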

// Error reporting: when reporting user-caused errors, these functions should
// not use AT_ERROR macros, since those macros add stack trace information
// that is confusing to display to the end user: it always reports locations
// in libtorch code rather than in user code.

inline std::shared_ptr<CompilationUnit> get_python_cu() {
  return py::module::import("torch.jit._state")
      .attr("_python_cu")
      .cast<std::shared_ptr<CompilationUnit>>();
}

struct TypedIValue : public std::pair<IValue, TypePtr> {
  using pair::pair;

  IValue& ivalue() {
    return this->first;
  }
  TypePtr& type() {
    return this->second;
  }
};

inline TypedIValue toDictKeyIValue(py::handle key) {
  if (py::isinstance<py::str>(key)) {
    return TypedIValue(
        ConstantString::create(py::cast<std::string>(key)), StringType::get());
  } else if (py::isinstance<py::int_>(key)) {
    return TypedIValue(py::cast<int64_t>(key), IntType::get());
  } else if (py::isinstance<py::float_>(key)) {
    return TypedIValue(py::cast<double>(key), FloatType::get());
  } else {
    AT_ERROR("Dictionary inputs may only have string, int, or float keys");
  }
}

inline std::optional<TypePtr> unifyOrInitializeType(
    const TypePtr& accum,
    const TypePtr& unify) {
  if (!accum) {
    return unify;
  }
  return unifyTypes(accum, unify);
}
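
// Behavior sketch (illustrative only): when no accumulated type exists yet,
// the new type simply initializes it; otherwise the two types are unified and
// std::nullopt is returned if they are incompatible.
//
//   auto t1 = unifyOrInitializeType(nullptr, IntType::get());   // Int
//   auto t2 = unifyOrInitializeType(*t1, IntType::get());       // Int
//   auto t3 = unifyOrInitializeType(*t2, StringType::get());    // likely nullopt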

using InferredType = c10::InferredType;

InferredType tryToInferContainerType(py::handle input, bool primitiveTypeOnly);

// Try to infer the type of a Python object
// The type cannot be inferred if:
//   input is an empty container (list, dict)
//   input is a list with element types that cannot be unified
//   input is a dict with key or value types that cannot be unified
inline InferredType tryToInferType(py::handle input) {
  // Try tensor types
  if (THPVariable_Check(input.ptr())) {
    return InferredType(TensorType::get());
  }

  if (input.is_none()) {
    return InferredType(NoneType::get());
  }

  if (py::isinstance<StrongFunctionPtr>(input)) {
    auto fn = py::cast<StrongFunctionPtr>(input).function_;
    return InferredType(FunctionType::create(fn));
  }

  // Try basic types first
  if (py::isinstance<py::bool_>(input)) {
    return InferredType(BoolType::get());
    // NOLINTNEXTLINE(bugprone-branch-clone)
  } else if (py::isinstance<py::int_>(input)) {
    return InferredType(IntType::get());
  } else if (py::isinstance<py::float_>(input)) {
    return InferredType(FloatType::get());
  } else if (PyComplex_CheckExact(input.ptr())) {
    return InferredType(ComplexType::get());
  } else if (py::isinstance<py::str>(input)) {
    return InferredType(StringType::get());
  } else if (THPLayout_Check(input.ptr())) {
    return InferredType(IntType::get());
  } else if (THPDevice_Check(input.ptr())) {
    return InferredType(DeviceObjType::get());
  } else if (THPGenerator_Check(input.ptr())) {
    return InferredType(GeneratorType::get());
  } else if (THPStream_Check(input.ptr())) {
    return InferredType(StreamObjType::get());
  } else if (THPDtype_Check(input.ptr())) {
    return InferredType(IntType::get());
  } else if (THPQScheme_Check(input.ptr())) {
    return InferredType(IntType::get());
  } else if (THPLayout_Check(input.ptr())) {
    return InferredType(IntType::get());
  }

  auto enum_type = py::module::import("enum").attr("Enum");
  py::bool_ isEnumValue = py::isinstance(input, enum_type);
  if (py::cast<bool>(isEnumValue)) {
    auto enum_class = input.attr("__class__");
    auto enum_type = py::cast<TypePtr>(
        py::module::import("torch.jit.annotations")
            .attr("try_ann_to_type")(enum_class, SourceRange()));
    return InferredType(std::move(enum_type));
  }

  py::bool_ isClass =
      py::module::import("inspect").attr("isclass")(input.get_type());
  if (py::cast<bool>(isClass)) {
    // Assume that the class is compiled already or will compile. Invalidate
    // this later if needed.
    bool class_compiled = true;

    // Check if the type is already compiled.
    py::object existing_ty = py::module::import("torch.jit._state")
                                 .attr("_get_script_class")(input.get_type());

    if (existing_ty.is_none()) {
      // If not, try to compile it.
      py::bool_ can_compile = py::module::import("torch._jit_internal")
                                  .attr("can_compile_class")(input.get_type());

      if (py::cast<bool>(can_compile)) {
        // Try to compile the class. This is wrapped in a try-catch because
        // compilation of class types can raise an Exception and in that case,
        // we want to defer to other attempts at type inference below rather
        // than fail compilation altogether.
        try {
          py::module::import("torch.jit._script")
              .attr("_recursive_compile_class")(
                  input.get_type(), SourceRange());
        } catch (...) {
          // Invalidate the assumption that the class compiled so that we don't
          // look up and return its JIT type as the type for the input.
          class_compiled = false;
        }
      }
    }

    // If the class compiled successfully, look up the existing JIT type by
    // qualified name and return it.
    if (class_compiled) {
      auto script_class = py::module::import("torch.jit._state")
                              .attr("_get_script_class")(input.get_type());

      if (!script_class.is_none()) {
        auto class_type = py::cast<ClassTypePtr>(script_class);

        if (class_type && !class_type->is_module()) {
          return InferredType(std::move(class_type));
        }
      }
    }
  }

  if (py::isinstance<Object>(input)) {
    auto object = py::cast<Object>(input);
    return InferredType(object.type());
#ifdef USE_RPC
  } else if (py::isinstance<torch::distributed::rpc::PyRRef>(input)) {
    auto rref_ivalue = input.cast<torch::distributed::rpc::PyRRef>().toIValue();
    return InferredType(rref_ivalue.type());
#endif
  }

  auto await_type = py::module::import("torch._awaits").attr("_Await");
  py::bool_ is_await = py::isinstance(input, await_type);
  if (py::cast<bool>(is_await)) {
    auto awptr = input.cast<std::shared_ptr<PythonAwaitWrapper>>();
    return InferredType(AwaitType::create(awptr->aw_->elementType()));
  }

  if (as_module(py::cast<py::object>(input))) {
    return InferredType("Cannot infer type of ScriptModule");
  }

  auto module_type = py::module::import("torch.nn").attr("Module");
  py::bool_ is_module = py::isinstance(input, module_type);
  if (py::cast<bool>(is_module)) {
    return InferredType("Cannot infer concrete type of torch.nn.Module");
  }

  // Try container types
  return tryToInferContainerType(input, false);
}
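
// Inference sketch (illustrative only, GIL assumed to be held): a homogeneous
// Python dict maps to a Dict type, while an empty or mixed container yields
// an unsuccessful InferredType carrying the reason.
//
//   py::dict d;
//   d[py::str("a")] = py::int_(1);
//   d[py::str("b")] = py::int_(2);
//   auto inferred = tryToInferType(d);
//   if (inferred.success()) {
//     TypePtr t = inferred.type();  // Dict[str, int]
//   }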

// This function is similar to tryToInferType, but it only tries to infer
// primitive types (int, float, bool, complex) or nested containers of
// primitive types.
inline InferredType tryToInferPrimitiveType(py::handle input) {
  if (input.is_none()) {
    return InferredType(NoneType::get());
  }

  // Only primitive data types
  if (py::isinstance<py::bool_>(input)) {
    return InferredType(BoolType::get());
    // NOLINTNEXTLINE(bugprone-branch-clone)
  } else if (py::isinstance<py::int_>(input)) {
    return InferredType(IntType::get());
  } else if (py::isinstance<py::float_>(input)) {
    return InferredType(FloatType::get());
  } else if (PyComplex_CheckExact(input.ptr())) {
    return InferredType(ComplexType::get());
  }

  // Try container types
  return tryToInferContainerType(input, true);
}

inline InferredType tryToInferContainerType(
    py::handle input,
    bool primitiveTypeOnly = false) {
  if (six::isTuple(input)) {
    py::tuple tuple = py::cast<py::tuple>(input);
    std::vector<TypePtr> element_types;
    element_types.reserve(tuple.size());

    for (py::handle elem : tuple) {
      auto type_match = primitiveTypeOnly ? tryToInferPrimitiveType(elem)
                                          : tryToInferType(elem);
      if (type_match.success()) {
        element_types.push_back(type_match.type());
      } else {
        // Forward error message along
        return type_match.reason();
      }
    }
    return InferredType(TupleType::create(std::move(element_types)));
  } else if (PyDict_Check(input.ptr())) {
    // Check to make sure we can generate useful input/output types
    auto dict = py::cast<py::dict>(input);
    size_t len = py::len(dict);
    if (!len) {
      return InferredType("Dictionary inputs must have entries");
    }

    TypePtr key_type = nullptr;
    TypePtr value_type = nullptr;

    for (auto entry : dict) {
      // Try to infer the key type and unify it with the existing one
      auto entry_key_type_match = primitiveTypeOnly
          ? tryToInferPrimitiveType(entry.first)
          : tryToInferType(entry.first);
      if (!entry_key_type_match.success()) {
        return entry_key_type_match.reason();
      }
      auto unified_key =
          unifyOrInitializeType(key_type, entry_key_type_match.type());
      if (!unified_key) {
        return InferredType(c10::str(
            "Dictionary inputs to traced functions must have consistent type. Found ",
            key_type->repr_str(),
            " and ",
            (entry_key_type_match.type())->repr_str()));
      }

      // Try to infer the value type and unify it with the existing one
      auto entry_value_type_match = primitiveTypeOnly
          ? tryToInferPrimitiveType(entry.second)
          : tryToInferType(entry.second);
      if (!entry_value_type_match.success()) {
        return entry_value_type_match.reason();
      }
      auto unified_value =
          unifyOrInitializeType(value_type, entry_value_type_match.type());
      if (!unified_value) {
        return InferredType(c10::str(
            "Dictionary inputs to traced functions must have consistent type. Found ",
            value_type->repr_str(),
            " and ",
            (entry_value_type_match.type())->repr_str()));
      }

      key_type = *unified_key;
      value_type = *unified_value;
    }
    return InferredType(
        DictType::create(std::move(key_type), std::move(value_type)));
  } else if (PyList_Check(input.ptr())) {
    auto list = py::cast<py::list>(input);
    size_t len = py::len(list);
    if (!len) {
      return InferredType("List trace inputs must have elements");
    }

    TypePtr element_type = nullptr;
    for (auto elem : list) {
      auto element_type_match = primitiveTypeOnly
          ? tryToInferPrimitiveType(elem)
          : tryToInferType(elem);
      if (!element_type_match.success()) {
        return InferredType(c10::str(
            "Could not infer type of list element: ",
            element_type_match.reason()));
      }
      auto unified_type =
          unifyOrInitializeType(element_type, element_type_match.type());
      if (!unified_type) {
        return InferredType(c10::str(
            "List inputs to traced functions must have consistent element type. Found ",
            element_type->repr_str(),
            " and ",
            (element_type_match.type())->repr_str()));
      }
      element_type = *unified_type;
    }
    return InferredType(ListType::create(element_type));
  } else {
    if (primitiveTypeOnly) {
      return InferredType(c10::str(
          "Only tuple, list, or dict (possibly nested) of primitive types (bool, float, int, complex) ",
          "are supported ",
          "as inputs or outputs of traced functions",
          ", but instead got value of type ",
          py::str(input.get_type().attr("__name__")),
          "."));
    } else {
      // TODO: this message is not correct anymore, since this InferredType is
      // used in a bunch of circumstances unrelated to tracing. We can reuse
      // this instead of the attribute_failure stuff in concreteType.
      return InferredType(c10::str(
          "Only tensors and (possibly nested) tuples of tensors, lists, or dicts ",
          "are supported ",
          "as inputs or outputs of traced functions",
          ", but instead got value of type ",
          py::str(input.get_type().attr("__name__")),
          "."));
    }
  }
}

inline bool isTraceableType(const TypePtr& type) {
  if (type->isSubtypeOf(*TensorType::get())) {
    return true;
  }

  if (auto list_type = type->cast<ListType>()) {
    return isTraceableType(list_type->getElementType());
  }

  if (auto tuple_type = type->cast<TupleType>()) {
    return std::all_of(
        tuple_type->elements().begin(),
        tuple_type->elements().end(),
        [](const TypePtr& element_type) {
          return isTraceableType(element_type);
        });
  }

  if (auto dict_type = type->cast<DictType>()) {
    return isTraceableType(dict_type->getValueType());
  }

  return false;
}
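
// Behavior sketch (illustrative only): traceability is defined recursively
// over containers, bottoming out at Tensor.
//
//   bool a = isTraceableType(ListType::create(TensorType::get()));  // true
//   bool b = isTraceableType(ListType::create(IntType::get()));     // false
//   bool c = isTraceableType(
//       DictType::create(StringType::get(), TensorType::get()));    // true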

inline IValue toTypeInferredIValue(py::handle input) {
  auto match = tryToInferType(input);
  if (!match.success()) {
    auto object = py::cast<py::object>(input);
    if (auto mod = as_module(object)) {
      // if obj is already a ScriptModule, just return its ivalue
      auto ptr = mod.value()._ivalue();
      // explicit copy semantics for strong ownership of the resource.
      return c10::intrusive_ptr<c10::ivalue::Object>::reclaim_copy(
          ptr.release());
    }

    // Check if the obj is a ScriptObject.
    if (auto script_obj = as_object(object)) {
      auto ptr = script_obj.value()._ivalue();
      return c10::intrusive_ptr<c10::ivalue::Object>::reclaim_copy(
          ptr.release());
    }
    AT_ERROR(
        "Tracer cannot infer type of ", py::str(input), "\n:", match.reason());
  }
  return toIValue(input, match.type());
}

inline Stack toTraceableStack(const py::tuple& inputs) {
  auto info = toTypeInferredIValue(inputs);
  TORCH_CHECK(
      isTraceableType(info.type()),
      "Type '",
      info.type()->repr_str(),
      "' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and"
      " Tuples of Tensors can be traced");
  return info.toTupleRef().elements().vec();
}

// Serialize the Python dictionary into a traceable stack.
inline Stack toTraceableStack(const py::dict& inputs) {
  Stack res;
  for (auto it = inputs.begin(); it != inputs.end(); it++) {
    if (THPVariable_Check(it->second.ptr())) {
      res.push_back(toIValue(it->second, tryToInferType(it->second).type()));
    }
  }
  return res;
}

inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
  auto elems = c10::impl::GenericList(elem_type);
  for (auto elem : obj) {
    elems.push_back(toIValue(elem, elem_type));
  }
  return IValue(elems);
}

inline IValue createGenericDict(
    const py::dict& obj,
    const TypePtr& key_type,
    const TypePtr& value_type) {
  c10::impl::GenericDict elems(key_type, value_type);
  elems.reserve(py::len(obj));
  for (auto& entry : obj) {
    elems.insert(
        toIValue(entry.first, key_type), toIValue(entry.second, value_type));
  }
  return IValue(elems);
}
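
// Usage sketch (illustrative only, GIL assumed to be held): building typed
// generic containers from Python objects with explicit element types.
//
//   py::list py_list;
//   py_list.append(py::int_(1));
//   py_list.append(py::int_(2));
//   IValue int_list = createGenericList(py_list, IntType::get());
//
//   py::dict py_map;
//   py_map[py::str("x")] = py::float_(0.5);
//   IValue str_to_float =
//       createGenericDict(py_map, StringType::get(), FloatType::get());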

template <class T>
inline void guardAgainstNamedTensor(const T& var) {
  TORCH_CHECK(
      !var.has_names(),
      "NYI: Named tensors are currently unsupported in TorchScript. As a "
      "workaround please drop names via `tensor = tensor.rename(None)`.");
}

// Extract custom class registered with torchbind
template <typename T>
c10::intrusive_ptr<T> toCustomClass(py::handle obj) {
  static_assert(
      std::is_base_of_v<CustomClassHolder, T>, "T is not a CustomClass");
  const auto& type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
  c10::IValue ivalue = toIValue(obj, type);
  return std::move(ivalue).toCustomClass<T>();
}
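
// Usage sketch (illustrative only): `MyStack` is a hypothetical torchbind
// class (any type registered via torch::class_ would do); the Python object
// must wrap an instance of that registered class.
//
//   // c10::intrusive_ptr<MyStack> stack = toCustomClass<MyStack>(py_obj);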

// Small wrapper around getting the type name string from Python to make
// types easier to interpret, e.g. give the structural type for a NamedTuple
inline std::string friendlyTypeName(py::handle obj) {
  if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
    auto field_names =
        py::cast<std::vector<std::string>>(py::getattr(obj, "_fields"));
    std::stringstream ss;
    ss << py::str(obj.get_type().attr("__name__"));
    ss << " (aka NamedTuple(";
    bool first = true;
    for (auto& field_name : field_names) {
      if (!first) {
        ss << ", ";
      }
      ss << field_name;
      first = false;
    }
    ss << "))";
    return ss.str();
  } else {
    return py::str(obj.get_type().attr("__name__"));
  }
}

// Thrown when trying to create a schema for a list of Python
// arguments that cannot be converted.
// Can be caught by the caller to attempt to use other schemas
// when there is an overloaded operator.
struct schema_match_error : public std::runtime_error {
  using std::runtime_error::runtime_error;
};

inline IValue argumentToIValue(
    const FunctionSchema& schema,
    size_t argumentPosition,
    py::handle object) {
  const auto& argument = schema.arguments().at(argumentPosition);
  try {
    return toIValue(object, argument.real_type(), argument.N());
  } catch (const py::cast_error& error) {
    throw schema_match_error(c10::str(
        schema.formatTypeMismatchMsg(
            argument,
            friendlyTypeName(object),
            argumentPosition,
            py::repr(object)),
        "\nCast error details: ",
        error.what()));
  } catch (const py::error_already_set& error) {
    throw schema_match_error(c10::str(
        schema.formatTypeMismatchMsg(
            argument,
            friendlyTypeName(object),
            argumentPosition,
            py::repr(object)),
        "\n Python error details: ",
        error.what()));
  }
}

inline IValue returnToIValue(const TypePtr& type, py::handle object) {
  try {
    return toIValue(object, type);
  } catch (const py::cast_error& error) {
    throw std::runtime_error(c10::str(
        " expected value of type ",
        type->str(),
        " for return value but instead got value of type ",
        py::str(object.get_type().attr("__name__")),
        ".",
        "\nValue: ",
        py::repr(object),
        "\nCast error details: ",
        error.what()));
  }
}

inline py::object getScriptedClassOrError(const c10::NamedTypePtr& classType) {
  auto py_class =
      py::module::import("torch.jit._state")
          .attr("_get_python_class")(classType->name()->qualifiedName());
  if (py_class.is_none()) {
    std::stringstream err;
    err << "Unknown reference to ScriptClass ";
    err << classType->name()->qualifiedName();
    err << ". (Did you forget to import it?)";
    throw std::runtime_error(err.str());
  }
  return py_class;
}

struct VISIBILITY_HIDDEN tuple_slice {
  /*implicit*/ tuple_slice(py::tuple tup_)
      : tup(std::move(tup_)), b(0), e(tup.size()) {}
  tuple_slice(py::tuple tup_, int64_t b_)
      : tup(std::move(tup_)), b(b_), e(tup.size()) {}
  tuple_slice(py::tuple tup_, int64_t b_, int64_t e_)
      : tup(std::move(tup_)), b(b_), e(e_) {}
  py::detail::tuple_iterator begin() const {
    return {tup, static_cast<pybind11::ssize_t>(b)};
  }
  py::detail::tuple_iterator end() const {
    return {tup, static_cast<pybind11::ssize_t>(e)};
  }
  size_t size() const {
    return e - b;
  }
  py::detail::tuple_accessor operator[](size_t index) const {
    return {tup, static_cast<size_t>(b + index)};
  }

 private:
  py::tuple tup;
  int64_t b;
  int64_t e;
};
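
// Usage sketch (illustrative only): tuple_slice gives a view over a suffix of
// a py::tuple without copying it, e.g. to skip a bound `self` argument.
//
//   py::tuple all_args = py::make_tuple(py::int_(0), py::int_(1), py::int_(2));
//   tuple_slice args_no_self(all_args, 1);   // views elements 1..2
//   size_t n = args_no_self.size();          // 2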

inline bool validateFakeScriptObjectSchema(
    const c10::FunctionSchema& schema,
    size_t argumentPosition,
    py::handle object) {
  auto argument = schema.arguments().at(argumentPosition);
  auto class_type = argument.real_type()->expect<c10::ClassType>();
  auto fake_class_registry =
      py::module::import("torch._library.fake_class_registry");
  auto fake_class = fake_class_registry.attr("find_fake_class")(
      class_type->name().value().qualifiedName());
  if (!py::isinstance(object.attr("wrapped_obj"), fake_class)) {
    throw schema_match_error(c10::str(
        schema.formatTypeMismatchMsg(
            argument,
            friendlyTypeName(object),
            argumentPosition,
            py::repr(object.attr("wrapped_obj"))),
        "\nCast error details: ",
        argument.name(),
        " is expected to be a FakeScriptObject of ",
        class_type->name().value().qualifiedName()));
  }
  return true;
}

inline bool matchSchemaAllowFakeScriptObject(
    const FunctionSchema& schema,
    const tuple_slice& args,
    const py::kwargs& kwargs) {
  size_t all_arguments = args.size() + kwargs.size();
  if (all_arguments > schema.arguments().size()) {
    throw schema_match_error(c10::str(
        schema.name(),
        "() expected at most ",
        schema.arguments().size(),
        " argument(s) but received ",
        all_arguments,
        " argument(s). Declaration: ",
        schema));
  }

  int64_t arg_idx = 0;
  auto fake_class_registry =
      py::module::import("torch._library.fake_class_registry");

  // First push all positional args.
  for (const auto& arg : args) {
    // ...but refuse to do it if the schema says that this was supposed
    // to be keyword only
    if (schema.arguments()[arg_idx].kwarg_only()) {
      throw schema_match_error(c10::str(
          schema.name(),
          "() takes ",
          arg_idx,
          " positional argument(s) but ",
          args.size(),
          " was/were given.  Declaration: ",
          schema));
    }
    // Use the type information from the schema to convert the PyObject.
    const auto& argument = schema.arguments().at(arg_idx);
    if (argument.real_type()->kind() == TypeKind::ClassType &&
        py::isinstance(arg, fake_class_registry.attr("FakeScriptObject"))) {
      validateFakeScriptObjectSchema(schema, arg_idx, arg);
    } else {
      argumentToIValue(schema, arg_idx, arg);
    }

    arg_idx++;
  }

  // Now for every remaining non-positional argument in the schema, look for it
  // in the kwargs dict and push it if found, or use its default value if it
  // has one.
  size_t consumed_kwargs = 0;
  for (size_t i = arg_idx; i < schema.arguments().size(); ++i) {
    const auto& arg = schema.arguments()[i];
    if (kwargs.contains(arg.name().c_str())) {
      auto cur_kwarg = kwargs[arg.name().c_str()];
      if (arg.real_type()->kind() == TypeKind::ClassType &&
          py::isinstance(
              cur_kwarg, fake_class_registry.attr("FakeScriptObject"))) {
        validateFakeScriptObjectSchema(schema, i, cur_kwarg);
      } else {
        argumentToIValue(schema, i, cur_kwarg);
      }
      consumed_kwargs += 1;
    } else if (arg.default_value()) {
      continue;
    } else {
      throw schema_match_error(c10::str(
          schema.name(),
          "() is missing value for argument '",
          arg.name(),
          "'. Declaration: ",
          schema));
    }
  }

  if (consumed_kwargs != kwargs.size()) {
    std::vector<std::string> names;
    for (const auto& kwarg : kwargs) {
      names.emplace_back(py::cast<std::string>(kwarg.first));
    }
    throw schema_match_error(schema.findErrorInKwargs(names));
  }

  return true;
}

inline Stack createStackForSchema(
    const FunctionSchema& schema,
    const tuple_slice& args,
    const py::kwargs& kwargs,
    std::optional<IValue> self) {
  size_t all_arguments = (self ? 1 : 0) + args.size() + kwargs.size();
  if (all_arguments > schema.arguments().size()) {
    throw schema_match_error(c10::str(
        schema.name(),
        "() expected at most ",
        schema.arguments().size(),
        " argument(s) but received ",
        all_arguments,
        " argument(s). Declaration: ",
        schema));
  }
  Stack stack;
  stack.reserve(schema.arguments().size());

  int64_t arg_idx = 0;
  if (self) {
    push(stack, std::move(*self));
    arg_idx++;
  }
  // First push all positional args.
  for (const auto& arg : args) {
    // ...but refuse to do it if the schema says that this was supposed
    // to be keyword only
    if (schema.arguments()[arg_idx].kwarg_only()) {
      throw schema_match_error(c10::str(
          schema.name(),
          "() takes ",
          arg_idx,
          " positional argument(s) but ",
          self ? 1 + args.size() : args.size(),
          " was/were given.  Declaration: ",
          schema));
    }
    // Use the type information from the schema to convert the PyObject.
    push(stack, argumentToIValue(schema, stack.size(), arg));
    arg_idx++;
  }

  // Now for every remaining non-positional argument in the schema, look for it
  // in the kwargs dict and push it if found, or use its default value if it
  // has one.
  size_t consumed_kwargs = 0;
  for (size_t i = stack.size(); i < schema.arguments().size(); ++i) {
    const auto& arg = schema.arguments()[i];
    if (kwargs.contains(arg.name().c_str())) {
      push(stack, argumentToIValue(schema, i, kwargs[arg.name().c_str()]));
      consumed_kwargs += 1;
    } else if (arg.default_value()) {
      push(stack, *arg.default_value());
    } else {
      throw schema_match_error(c10::str(
          schema.name(),
          "() is missing value for argument '",
          arg.name(),
          "'. Declaration: ",
          schema));
    }
  }

  if (consumed_kwargs != kwargs.size()) {
    std::vector<std::string> names;
    for (const auto& kwarg : kwargs) {
      names.emplace_back(py::cast<std::string>(kwarg.first));
    }
    throw schema_match_error(schema.findErrorInKwargs(names));
  }

  return stack;
}
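
// Usage sketch (illustrative only): converting Python args/kwargs into an
// interpreter Stack for a hypothetical operator schema such as
// "my_ns::my_op(Tensor a, int b=1) -> Tensor". The resulting stack holds one
// IValue per schema argument, in declaration order, with defaults filled in.
//
//   // Stack stack = createStackForSchema(
//   //     op->schema(), tuple_slice(std::move(py_args)), py_kwargs,
//   //     /*self=*/std::nullopt);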

inline py::object createPyObjectForStack(Stack&& stack) {
  if (stack.empty()) {
    return py::none();
  }

  // Return a simple value and not a single-element tuple if there is only one
  // return value.
  if (stack.size() == 1) {
    return toPyObject(std::move(stack[0]));
  }

  // If there is more than one return value, pop them into a py::tuple.
  py::tuple return_values(stack.size());
  for (const auto ret : c10::irange(return_values.size())) {
    return_values[ret] = toPyObject(std::move(stack[ret]));
  }

  return std::move(return_values);
}

// TODO: Remove once we clean up the GraphExecutor usage.
inline Stack evilDeprecatedBadCreateStackDoNotUse(
    const py::tuple& tuple,
    at::ArrayRef<Value*> inputs,
    size_t reserve_extra_space = 0) {
  if (tuple.size() != inputs.size()) {
    AT_ERROR(
        "expected " + std::to_string(inputs.size()) + " inputs, but got " +
        std::to_string(tuple.size()));
  }
  Stack result;
  result.reserve(tuple.size() + reserve_extra_space);
  for (const auto i : c10::irange(inputs.size())) {
    result.push_back(toIValue(std::move(tuple[i]), inputs[i]->type()));
  }
  return result;
}

// Run `callee`, potentially inserting a CallFunction/CallMethod node into the
// tracing graph.
inline py::object runAndInsertCall(
    Function& callee,
    const tuple_slice& args,
    const py::kwargs& kwargs,
    std::optional<IValue> self,
    // Lambda that tells this function how to insert `callee` into the graph if
    // we're tracing.
    const std::function<Value*(Graph&, const MatchedSchema& match)>&
        callInserter) {
  auto stack =
      createStackForSchema(callee.getSchema(), args, kwargs, std::move(self));
  const auto& tracing_state = tracer::getTracingState();
  if (!tracing_state) {
    pybind11::gil_scoped_release no_gil_guard;
    // If we're not tracing, just run the callee as normal.
    callee.run(stack);
  } else {
    // If we are tracing, insert the appropriate CallFunction or CallMethod node
    // and then run the callee with tracing disabled.

    // Get the graph `Value`s that represent the input IValues
    auto inputs = last(stack, callee.num_inputs());
    auto input_values =
        fmap(inputs, [](const IValue& v) { return tracer::getValueTrace(v); });
    TORCH_INTERNAL_ASSERT(callee.getSchema().returns().size() == 1);
    auto return_type = callee.getSchema().returns().at(0).type();
    auto graph = tracing_state->graph;
    std::vector<NamedValue> named_values;
    named_values.reserve(input_values.size());
    for (Value* v : input_values) {
      named_values.emplace_back(v);
    }

    // Add a call node.
    MatchedSchema match = matchSchema(
        callee.getSchema(),
        tracer::getPythonInterpreterSourceRange(),
        *graph,
        named_values,
        {});
    auto output_value = callInserter(*graph, match);

    // Actually run the callee. Pause the tracer so that we don't double-add the
    // callee nodes.
    {
      pybind11::gil_scoped_release no_gil_guard;
      ResourceGuard guard(tracer::pauseTracing());
      callee.run(stack);
    }

    // Associate the output IValues with the output `Value`s in the graph
    tracer::setValueTrace(stack.back(), output_value);
  }

  TORCH_CHECK(
      !stack.empty(),
      "Expected values in the stack after execution but found none");
  return toPyObject(std::move(stack.back()));
}

inline std::optional<py::object> maybeTorchFunctionDispatch(
    const py::object& callee,
    const tuple_slice& args_no_self,
    const py::kwargs& kwargs,
    const c10::QualifiedName& qualname) {
  std::vector<py::handle> args_vec;
  for (const auto& arg : args_no_self) {
    args_vec.push_back(arg);
  }
  py::tuple args = py::cast(args_vec);

  // Handle __torch_function__ dispatch
  std::vector<PyObject*> overloaded_args;
  size_t total_arg_num = args.size() + kwargs.size();
  for (const auto& arg : args) {
    is_tensor_and_append_overloaded(arg.ptr(), &overloaded_args);
    is_tensor_list_and_append_overloaded(
        arg.ptr(),
        &overloaded_args,
        static_cast<int>(total_arg_num),
        false /* throw_error */);
  }
  // NB: for kwargs, we cannot guarantee that the order of appending is the
  // same as the argument order in the operator's schema. This is suboptimal,
  // but should be fine. Later, when we have better schema matching and
  // argument parsing, we could match the operator in `operations` first, and
  // then the order will be guaranteed.
  for (auto item : kwargs) {
    is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
    is_tensor_list_and_append_overloaded(
        item.second.ptr(),
        &overloaded_args,
        total_arg_num,
        false /* throw_error */);
  }
  if (!overloaded_args.empty()) {
    return pybind11::reinterpret_steal<py::object>(
        handle_torch_function_no_python_arg_parser(
            /*overloaded_args=*/overloaded_args,
            /*args=*/args.ptr(),
            /*kwargs=*/kwargs.ptr(),
            /*func_name=*/qualname.name().c_str(),
            /*torch_api_function=*/callee.ptr(),
            /*module_name=*/qualname.prefix().c_str()));
  }

  return std::nullopt;
}

inline py::object invokeScriptFunctionFromPython(
    Function& callee,
    const tuple_slice& args,
    const py::kwargs& kwargs) {
  // TODO: we could add __torch_function__ dispatch here but I don't know
  // the implications of doing so

  return runAndInsertCall(
      callee,
      args,
      kwargs,
      /*self=*/std::nullopt,
      [&](Graph& graph, const MatchedSchema& match) {
        return graph.insertFunctionCall(&callee, match);
      });
}

inline py::object invokeScriptMethodFromPython(
    Method& callee,
    const tuple_slice& args,
    const py::kwargs& kwargs) {
  auto self = callee.owner()._ivalue();

  if (auto torch_fn_result = maybeTorchFunctionDispatch(
          py::cast(callee), args, kwargs, callee.name())) {
    return *torch_fn_result;
  }

  return runAndInsertCall(
      callee.function(),
      args,
      kwargs,
      self,
      [&](Graph& graph, const MatchedSchema& match) {
        return graph.insertMethodCall(callee.name(), match);
      });
}

TORCH_PYTHON_API std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
    const std::vector<std::shared_ptr<Operator>>& operations,
    const py::args& args,
    const py::kwargs& kwargs);

TORCH_PYTHON_API py::object invokeOperatorFromPython(
    const std::vector<std::shared_ptr<Operator>>& operations,
    const py::args& args,
    const py::kwargs& kwargs,
    std::optional<c10::DispatchKey> dk = std::nullopt);

TORCH_PYTHON_API std::optional<py::object> _maybe_handle_torch_function(
    const std::string& ns,
    const std::string& method_name,
    const std::string& overload_name,
    bool is_overload,
    const py::args& args,
    const py::kwargs& kwargs);

TORCH_PYTHON_API bool checkSchemaAllowFakeScriptObject(
    const FunctionSchema& schema,
    const py::args& args,
    const py::kwargs& kwargs);

TORCH_PYTHON_API py::object _get_operation_for_overload_or_packet(
    const std::vector<std::shared_ptr<Operator>>& operations,
    Symbol symbol,
    const py::args& args,
    const py::kwargs& kwargs,
    bool is_overload,
    std::optional<c10::DispatchKey> dk = std::nullopt);

} // namespace torch::jit