#include <torch/csrc/jit/ir/graph_utils.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_dict.h>
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_list.h>
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>

#include <ATen/ScalarOps.h>

#include <c10/core/QScheme.h>
#include <c10/util/irange.h>
#include <torch/csrc/utils/python_arg_parser.h>

#include <limits>
#include <optional>
#include <utility>

namespace torch::jit {

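// When true, toIValue() may convert a bare Python number (bool/int/float/
// complex, or its symbolic counterpart) into a 0-dim "wrapped number" Tensor
// when a TensorType is requested. Toggled via the RAII guard below, which
// restores the previous value on destruction so that guards nest safely.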
static thread_local bool allow_numbers_as_tensors = false;

ToIValueAllowNumbersAsTensors::ToIValueAllowNumbersAsTensors(bool enable)
    : old_(allow_numbers_as_tensors) {
  allow_numbers_as_tensors = enable;
}

ToIValueAllowNumbersAsTensors::~ToIValueAllowNumbersAsTensors() {
  allow_numbers_as_tensors = old_;
}

// This is a hack to remove instances deleted in C++ from the PyBind cache
// C++->Python. We need this because otherwise we may get the old Python object
// if C++ creates a new object at the memory location of the deleted object.
void clear_registered_instances(void* ptr) {
  auto& registered_instances =
      pybind11::detail::get_internals().registered_instances;
  auto range = registered_instances.equal_range(ptr);
  for (auto it = range.first; it != range.second; ++it) {
    auto vh = it->second->get_value_and_holder();
    vh.set_instance_registered(false);
  }
  registered_instances.erase(ptr);
}
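
// A minimal sketch of the intended call pattern (names hypothetical):
//
//   Foo* raw = foo_holder.get();
//   foo_holder.reset();               // the C++ object is destroyed
//   clear_registered_instances(raw);  // drop the stale C++->Python cache entry
//
// Without the second call, pybind11 could later hand out the old Python
// wrapper for a fresh object allocated at the same address.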

// WARNING: Precondition for this function is that, e.g., you have tested if a
// SymIntList is in fact only ints, and if so, you called this with T=int64_t.
// This precondition is NOT checked at runtime.
template <typename T>
IValue listToIValue(py::handle obj) {
  c10::List<T> rs;
  for (auto it = obj.begin(); it != obj.end(); it++) {
    auto elm = *it;
    rs.push_back(py::cast<T>(elm));
  }
  // Promises that we have decayed the list appropriately
  return c10::impl::toList<T>(rs);
}

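// Convert the Python object `obj` to an IValue of the JIT type `type`.
// `N`, when set, lets a single int/float broadcast to a fixed-size list
// (see the ListType case below). Throws py::cast_error when the object
// cannot be converted. A minimal usage sketch:
//
//   IValue iv = toIValue(py_obj, FloatType::get());  // Python float -> double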
IValue toIValue(py::handle obj, const TypePtr& type, std::optional<int32_t> N) {
  switch (type->kind()) {
    case TypeKind::TensorType: {
      if (obj.ptr() == Py_None) {
        // None gets converted to undefined Tensors
        return autograd::Variable();
      }
      if (THPVariable_Check(obj.ptr())) {
        auto var = py::cast<autograd::Variable>(obj);
        guardAgainstNamedTensor<autograd::Variable>(var);
        return var;
      } else {
        if (!allow_numbers_as_tensors) {
          throw py::cast_error(
              c10::str("Unable to cast ", py::str(obj), " to Tensor"));
        }
        bool save_symint = false;
        at::Scalar scalar;
        if (PyBool_Check(obj.ptr())) {
          scalar = at::Scalar(THPUtils_unpackBool(obj.ptr()));
        } else if (THPUtils_checkLong(obj.ptr())) {
          scalar = at::Scalar(THPUtils_unpackLong(obj.ptr()));
        } else if (PyComplex_Check(obj.ptr())) {
          scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr()));
        } else if (THPUtils_checkDouble(obj.ptr())) {
          scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr()));
        } else if (torch::is_symint(py::handle(obj))) {
          save_symint = true;
          scalar = at::Scalar(7777777);
        } else if (torch::is_symfloat(py::handle(obj))) {
          save_symint = true;
          scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
        } else if (torch::is_symbool(py::handle(obj))) {
          save_symint = true;
          scalar = at::Scalar(true);
        } else {
          throw py::cast_error(
              c10::str("Unable to cast ", py::str(obj), " to Tensor"));
        }
        at::Tensor tensor = at::scalar_to_tensor(scalar);
        tensor.unsafeGetTensorImpl()->set_wrapped_number(true);

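        // A SymInt/SymFloat/SymBool has no exact Scalar representation, so
        // the tensor built above only carries a placeholder value. Stash the
        // original Python object on the tensor as `_wrapped_number` so that
        // toPyObject() can recover it on the way back out.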
        if (save_symint) {
          auto py_tensor = py::cast(tensor);
          if (PyObject_SetAttrString(
                  py_tensor.ptr(), "_wrapped_number", obj.ptr()) < 0) {
            throw python_error();
          }
        }

        return tensor;
      }
    }
    case TypeKind::StorageType:
      return py::cast<at::Storage>(obj);
    case TypeKind::FloatType:
      if (torch::is_symfloat(py::handle(obj))) {
        return py::cast<c10::SymFloat>(obj).guard_float(__FILE__, __LINE__);
      }
      if (THPVariable_Check(obj.ptr())) {
        auto var = py::cast<autograd::Variable>(obj);
        // NB: We carefully test if the storage is meta, because that is
        // always accurate even if you have a fake tensor (which is the
        // primary case we are trying to detect here)
        if (var.storage().device_type() == c10::kMeta) {
          throw py::cast_error(
              "cannot extract float from tensor with meta storage");
        }
      }
      return py::cast<double>(obj);
    case TypeKind::ComplexType: {
      auto c_obj = py::cast<std::complex<double>>(obj.ptr());
      return static_cast<c10::complex<double>>(c_obj);
    }
    case TypeKind::IntType:
      // TODO: Properly fake this type
      if (THPQScheme_Check(obj.ptr())) {
        auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
        return static_cast<uint8_t>(qscheme->qscheme);
      }
      // For backwards compatibility
      if (THPDtype_Check(obj.ptr())) {
        auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
        return static_cast<int64_t>(dtype->scalar_type);
      }
      if (THPLayout_Check(obj.ptr())) {
        auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
        return static_cast<int8_t>(layout->layout);
      }
      if (THPMemoryFormat_Check(obj.ptr())) {
        auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
        return static_cast<int8_t>(memory_format->memory_format);
      }
      if (torch::is_symint(py::handle(obj))) {
        return py::cast<c10::SymInt>(obj).guard_int(__FILE__, __LINE__);
      }
      if (THPVariable_Check(obj.ptr())) {
        auto var = py::cast<autograd::Variable>(obj);
        if (var.storage().device_type() == c10::kMeta) {
          throw py::cast_error(
              "cannot extract int from tensor with meta storage");
        }
      }
      return py::cast<int64_t>(obj);
    case TypeKind::LayoutType: {
      if (THPLayout_Check(obj.ptr())) {
        auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
        return static_cast<int8_t>(layout->layout);
      }
      // For backwards compatibility
      return py::cast<int64_t>(obj);
    }
    case TypeKind::ScalarTypeType: {
      if (THPDtype_Check(obj.ptr())) {
        auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
        return static_cast<int64_t>(dtype->scalar_type);
      }
      // For backwards compatibility
      return py::cast<int64_t>(obj);
    }
    case TypeKind::MemoryFormatType: {
      if (THPMemoryFormat_Check(obj.ptr())) {
        auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
        return static_cast<int8_t>(memory_format->memory_format);
      }
      // For backwards compatibility
      return py::cast<int64_t>(obj);
    }
    case TypeKind::SymIntType:
      if (torch::is_symint(obj.ptr())) {
        return py::cast<c10::SymInt>(obj);
      }
      return py::cast<int64_t>(obj);
    case TypeKind::SymFloatType:
      if (torch::is_symfloat(obj.ptr())) {
        return py::cast<c10::SymFloat>(obj);
      }
      return py::cast<double>(obj);
    case TypeKind::SymBoolType:
      if (torch::is_symbool(obj.ptr())) {
        return py::cast<c10::SymBool>(obj);
      }
      return py::cast<bool>(obj);
    case TypeKind::NoneType:
      if (!obj.is_none()) {
        throw py::cast_error(
            c10::str("Cannot cast ", py::str(obj), " to None"));
      }
      return {};
    case TypeKind::BoolType:
      if (torch::is_symbool(obj.ptr())) {
        return py::cast<c10::SymBool>(obj).guard_bool(__FILE__, __LINE__);
      }
      if (THPVariable_Check(obj.ptr())) {
        auto var = py::cast<autograd::Variable>(obj);
        if (var.storage().device_type() == c10::kMeta) {
          throw py::cast_error(
              "cannot extract bool from tensor with meta storage");
        }
      }
      return py::cast<bool>(obj);
    case TypeKind::TupleType: {
      py::tuple tuple = py::cast<py::tuple>(obj);
      size_t tuple_size = tuple.size();
      auto tuple_type = type->cast<TupleType>();
      const auto& elem_types = tuple_type->elements();
      if (elem_types.size() != tuple_size) {
        throw py::cast_error(c10::str(
            "Object ",
            py::str(obj),
            " had a different number of elements than type ",
            type->repr_str()));
      }
      std::vector<IValue> values;
      values.reserve(tuple_size);
      for (const auto i : c10::irange(tuple_size)) {
        values.push_back(toIValue(tuple[i], elem_types[i]));
      }
      return tuple_type->name()
          ? c10::ivalue::Tuple::createNamed(std::move(values), tuple_type)
          : c10::ivalue::Tuple::create(std::move(values));
    }
    case TypeKind::UnionType: {
      auto actual_type = toTypeInferredIValue(obj);
      auto actual_type_ptr = actual_type.type();
      auto union_type = type->expect<UnionType>();
      if (!actual_type_ptr->isSubtypeOf(union_type)) {
        throw py::cast_error(c10::str(
            "Expected a member of ",
            union_type->annotation_str(),
            " but instead found type ",
            actual_type.type()->annotation_str()));
      }
      return actual_type;
    }
    case TypeKind::StringType:
      return ConstantString::create(py::cast<std::string>(obj));
    case TypeKind::DeviceObjType: {
      if (THPDevice_Check(obj.ptr())) {
        auto device = reinterpret_cast<THPDevice*>(obj.ptr());
        return device->device;
      }
      return c10::Device(py::cast<std::string>(obj.ptr()));
    }
    case TypeKind::StreamObjType: {
      auto thp_stream = reinterpret_cast<THPStream*>(obj.ptr());
      auto stream = c10::Stream::unpack3(
          thp_stream->stream_id,
          static_cast<c10::DeviceIndex>(thp_stream->device_index),
          static_cast<c10::DeviceType>(thp_stream->device_type));
      return stream;
    }
    case TypeKind::ListType: {
      // If the object is a ScriptList, retrieve the c10::List
      // instance inside it.
      if (py::isinstance<ScriptList>(obj)) {
        return py::cast<ScriptList>(obj).list_;
      }

      // If not (i.e. it is a regular Python list), make a new
      // c10::List.
      const auto& elem_type = type->expectRef<ListType>().getElementType();
      switch (elem_type->kind()) {
        // allows single int/float to be broadcasted to a fixed size list
        case TypeKind::IntType:
          if (!N || !py::isinstance<py::int_>(obj)) {
            return IValue(py::cast<std::vector<int64_t>>(obj));
          } else {
            int64_t value = py::cast<int64_t>(obj);
            c10::List<int64_t> repeated;
            repeated.reserve(*N);
            for (int i = 0; i < *N; ++i) {
              repeated.push_back(value);
            }
            return repeated;
          }
        case TypeKind::SymIntType: {
          bool is_symbolic = false;
          for (auto it = obj.begin(); it != obj.end(); it++) {
            auto elm = *it;
            if (torch::is_symint(elm)) {
              is_symbolic = true;
              break;
            }
          }
          if (is_symbolic) {
            return listToIValue<c10::SymInt>(obj);
          } else {
            return listToIValue<int64_t>(obj);
          }
        }
        case TypeKind::SymFloatType: {
          bool is_symbolic = false;
          for (auto it = obj.begin(); it != obj.end(); it++) {
            auto elm = *it;
            // TODO: what about SymInt conversion to SymFloat?
            if (torch::is_symfloat(elm)) {
              is_symbolic = true;
              break;
            }
          }
          if (is_symbolic) {
            return listToIValue<c10::SymFloat>(obj);
          } else {
            return listToIValue<double>(obj);
          }
        }
        case TypeKind::SymBoolType: {
          bool is_symbolic = false;
          for (auto it = obj.begin(); it != obj.end(); it++) {
            auto elm = *it;
            if (torch::is_symbool(elm)) {
              is_symbolic = true;
              break;
            }
          }
          if (is_symbolic) {
            return listToIValue<c10::SymBool>(obj);
          } else {
            return listToIValue<bool>(obj);
          }
        }
        case TypeKind::FloatType:
          if (!N || !py::isinstance<py::float_>(obj)) {
            return IValue(py::cast<std::vector<double>>(obj));
          } else {
            double value = py::cast<double>(obj);
            c10::List<double> repeated;
            repeated.reserve(*N);
            for (int i = 0; i < *N; ++i) {
              repeated.push_back(value);
            }
            return repeated;
          }
        case TypeKind::BoolType:
          return IValue(py::cast<std::vector<bool>>(obj));
        case TypeKind::TensorType:
          return IValue(py::cast<std::vector<at::Tensor>>(obj));
        default:
          return createGenericList(obj, elem_type);
      }
    }
    case TypeKind::DictType: {
      const auto& dict_type = type->expect<DictType>();

      // If the object is a ScriptDict, retrieve the c10::Dict
      // instance inside it.
      try {
        auto script_dict = py::cast<ScriptDict>(obj);
        return script_dict.dict_;
      } catch (py::cast_error& e) {
      }

      // If not (i.e. it is a regular Python dictionary), make a new
      // c10::Dict.
      return createGenericDict(
          py::cast<py::dict>(obj),
          dict_type->getKeyType(),
          dict_type->getValueType());
    }
    case TypeKind::OptionalType: {
      // check if it's a none obj since optional accepts NoneType
      if (obj.is_none()) {
        // return an IValue() to denote a NoneType
        return {};
      }
      return toIValue(obj, type->expectRef<OptionalType>().getElementType(), N);
    }
    case TypeKind::ClassType: {
      auto classType = type->expect<ClassType>();
      auto object = py::cast<py::object>(obj);
      if (auto mod = as_module(object)) {
        // if obj is already a ScriptModule, just return its ivalue
        return mod.value()._ivalue();
      }

      // Check if the obj is a ScriptObject.
      if (auto script_obj = as_object(object)) {
        return script_obj.value()._ivalue();
      }

      // Otherwise it is a normal class object; create a fresh
      // ivalue::Object to use from the py object.
      // 1. create a bare ivalue
      const size_t numAttrs = classType->numAttributes();
      auto cu = classType->compilation_unit();
      auto userObj = c10::ivalue::Object::create(
          c10::StrongTypePtr(cu, classType), numAttrs);

      // 2. copy all the contained types
      for (const auto slot : c10::irange(numAttrs)) {
        const auto& attrType = classType->getAttribute(slot);
        const auto& attrName = classType->getAttributeName(slot);

        if (!py::hasattr(obj, attrName.c_str())) {
          throw py::cast_error(c10::str(
              "Tried to cast object to type ",
              type->repr_str(),
              " but object",
              " was missing attribute ",
              attrName));
        }

        try {
          const auto& contained = py::getattr(obj, attrName.c_str());
          userObj->setSlot(slot, toIValue(contained, attrType));
        } catch (std::exception& e) {
          throw py::cast_error(c10::str(
              "Could not cast attribute '",
              attrName,
              "' to type ",
              attrType->repr_str(),
              ": ",
              e.what()));
        }
      }
      return userObj;
    }
    case TypeKind::InterfaceType: {
      auto interfaceType = type->expect<InterfaceType>();
      // When converting a pyobj to an interface, we check if the rhs is a
      // module or a normal torchscript class, and get the type and ivalue
      // from them correspondingly.
      c10::ClassTypePtr classType = nullptr;
      IValue res;
      if (auto mod = as_module(py::cast<py::object>(obj))) {
        classType = mod.value().type();
        res = mod.value()._ivalue();
      } else if (auto object = as_object(py::cast<py::object>(obj))) {
        classType = object.value().type();
        res = object.value()._ivalue();
      } else {
        // We inspect the value to find the compiled TorchScript class
        // and then create an ivalue::Object from that class type.
        py::str qualified_name = py::module::import("torch._jit_internal")
                                     .attr("_qualified_name")(obj.get_type());
        auto pyCu = get_python_cu();
        classType = pyCu->get_class(c10::QualifiedName(qualified_name));
        if (!classType) {
          throw std::runtime_error(c10::str(
              "Assigning the object ",
              py::str(obj),
              " to an interface fails because the value is not "
              "a TorchScript compatible type, did you forget to ",
              "turn it into a user defined TorchScript class?"));
        }
        res = toIValue(obj, classType);
      }
      // check if the classType conforms to the interface or not
      std::stringstream why_not;
      if (!classType->isSubtypeOfExt(*interfaceType, &why_not)) {
        throw py::cast_error(c10::str(
            "Object of type ",
            classType->repr_str(),
            " is not compatible with interface ",
            interfaceType->repr_str(),
            "\n",
            why_not.str()));
      }
      return res;
    }
    case TypeKind::NumberType: {
      if (THPDtype_Check(obj.ptr())) {
        auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
        return static_cast<int64_t>(dtype->scalar_type);
      }
      if (THPQScheme_Check(obj.ptr())) {
        auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
        return static_cast<uint8_t>(qscheme->qscheme);
      }
      if (THPLayout_Check(obj.ptr())) {
        auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
        return static_cast<int8_t>(layout->layout);
      }
      if (py::isinstance<py::bool_>(obj)) {
        return py::cast<bool>(obj);
      } else if (py::isinstance<py::int_>(obj)) {
        return py::cast<int64_t>(obj);
      } else if (py::isinstance<py::float_>(obj)) {
        return py::cast<double>(obj);
      } else if (PyComplex_CheckExact(obj.ptr())) {
        auto c_obj = py::cast<std::complex<double>>(obj.ptr());
        return static_cast<c10::complex<double>>(c_obj);
      } else if (torch::is_symint(obj)) {
        return py::cast<c10::SymInt>(obj);
      } else if (torch::is_symfloat(obj)) {
        return py::cast<c10::SymFloat>(obj);
      } else if (torch::is_symbool(obj)) {
        return py::cast<c10::SymBool>(obj);
      } else {
        throw py::cast_error(
            c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
      }
    }
    case TypeKind::RRefType: {
#ifdef USE_RPC
      return obj.cast<torch::distributed::rpc::PyRRef>().toIValue();
#else
      AT_ERROR("RRef is only supported with the distributed package");
#endif
    } break;
    case TypeKind::PyObjectType: {
      return c10::ivalue::ConcretePyObjectHolder::create(obj);
    }
    case TypeKind::CapsuleType: {
      return IValue::make_capsule(py::cast<c10::Capsule>(obj).obj_ptr);
    }
    case TypeKind::FutureType: {
      return obj.cast<std::shared_ptr<PythonFutureWrapper>>()->fut;
    }
    case TypeKind::AwaitType: {
      return obj.cast<std::shared_ptr<PythonAwaitWrapper>>()->aw_;
    }
    case TypeKind::AnyType:
      return toTypeInferredIValue(obj);
    case TypeKind::QSchemeType: {
      if (py::isinstance<py::int_>(obj)) {
        return static_cast<at::QScheme>(py::cast<int64_t>(obj));
      }
      throw py::cast_error(
          c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
    }
    case TypeKind::GeneratorType:
      return py::cast<at::Generator>(obj);
    case TypeKind::DynamicType:
    case TypeKind::FunctionType:
    case TypeKind::QuantizerType:
    case TypeKind::VarType:
    case TypeKind::AnyListType:
    case TypeKind::AnyTupleType:
    case TypeKind::AnyClassType:
    case TypeKind::AnyEnumType:
      break;
    case TypeKind::EnumType:
      EnumTypePtr enum_type = type->expect<EnumType>();
      py::object py_obj = py::reinterpret_borrow<py::object>(obj);
      std::string name = py::cast<std::string>(obj.attr("name"));
      IValue value = toIValue(obj.attr("value"), enum_type->getValueType(), {});
      auto enum_holder =
          c10::make_intrusive<c10::ivalue::EnumHolder>(enum_type, name, value);
      return IValue(enum_holder);
  }
  throw py::cast_error(c10::str(
      "toIValue() cannot handle converting to type: ", type->repr_str()));
}

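// Convert an IValue back into a Python object; this is (approximately) the
// inverse of toIValue(). A round trip like the following sketch should
// preserve the value:
//
//   IValue iv = toIValue(obj, type);
//   py::object back = toPyObject(std::move(iv));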
py::object toPyObject(IValue ivalue) {
  if (ivalue.isNone()) {
    return py::none();
  } else if (ivalue.isTensor()) {
    auto tensor = std::move(ivalue).toTensor();
    if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
      TORCH_INTERNAL_ASSERT(tensor.device().is_cpu());
      auto py_tensor = py::cast(tensor);
      if (PyObject_HasAttrString(py_tensor.ptr(), "_wrapped_number")) {
        return py_tensor.attr("_wrapped_number");
      }
      auto scalar_type = tensor.scalar_type();
      switch (scalar_type) {
        case at::ScalarType::Bool:
          return py::cast(*tensor.const_data_ptr<bool>());
        case at::ScalarType::Long:
          return py::cast(*tensor.const_data_ptr<int64_t>());
        case at::ScalarType::Double:
          return py::cast(*tensor.const_data_ptr<double>());
        case at::ScalarType::ComplexDouble:
          // TODO: https://github.com/pytorch/pytorch/issues/77134
          return py::cast(static_cast<std::complex<double>>(
              *tensor.const_data_ptr<c10::complex<double>>()));
        default:
          TORCH_CHECK(
              false,
              "Missing cases in 'toPyObject' wrapped number handling! Can't convert ",
              scalar_type,
              " to a Python object");
      }
    } else {
      guardAgainstNamedTensor<at::Tensor>(tensor);
      return py::cast(autograd::Variable(std::move(tensor)));
    }
  } else if (ivalue.isStorage()) {
    return py::cast(std::move(ivalue).toStorage());
  } else if (ivalue.isGenerator()) {
    return py::cast(std::move(ivalue).toGenerator());
  } else if (ivalue.isDouble()) {
    return py::cast(std::move(ivalue).toDouble());
  } else if (ivalue.isComplexDouble()) {
    return py::cast(
        static_cast<std::complex<double>>(std::move(ivalue).toComplexDouble()));
  } else if (ivalue.isInt()) {
    return py::cast(std::move(ivalue).toInt());
  } else if (ivalue.isBool()) {
    return py::cast(std::move(ivalue).toBool());
  } else if (ivalue.isString()) {
    if (getUTF8DecodingIgnore()) {
      std::string s = std::move(ivalue).toStringRef();
      PyObject* pyObj = PyUnicode_DecodeUTF8(s.data(), s.length(), "ignore");
      return py::reinterpret_steal<py::object>(pyObj);
    } else {
      return py::cast(std::move(ivalue).toStringRef());
    }
  } else if (ivalue.isList()) {
    auto list = std::move(ivalue).toList();
    py::list t{list.size()};
    for (const auto i : c10::irange(list.size())) {
      t[i] = toPyObject(IValue{list.get(i)});
    }
    return std::move(t);
  } else if (ivalue.isTuple()) {
    auto tuple = std::move(ivalue).toTuple();
    const auto& elements = tuple->elements();

    py::tuple t{elements.size()};
    for (const auto i : c10::irange(elements.size())) {
      t[i] = toPyObject(IValue{elements.at(i)});
    }

    // If we have a NamedTuple
    if (tuple->type() && tuple->type()->schema() &&
        !tuple->type()->schema()->name().empty()) {
      auto unqualName = tuple->type()->name()->name();

      std::vector<Argument> tuple_args = tuple->type()->schema()->arguments();

      std::vector<pybind11::object> defaults;
      auto it = std::find_if(
          tuple_args.begin(), tuple_args.end(), [](const Argument& arg) {
            return arg.default_value().has_value();
          });
      std::transform(
          it,
          tuple_args.end(),
          std::back_inserter(defaults),
          [](const Argument& arg) { return toPyObject(*arg.default_value()); });

      std::vector<std::string> fieldNames =
          fmap(tuple_args, [](const Argument& arg) { return arg.name(); });

      return py::module::import("torch._jit_internal")
          .attr("_create_named_tuple")(
              t, unqualName, fieldNames, py::make_tuple(defaults));
    } else {
      return std::move(t);
    }
  } else if (ivalue.isDevice()) {
    return py::cast(std::move(ivalue).toDevice());
  } else if (ivalue.isStream()) {
    return py::cast(std::move(ivalue).toStream());
  } else if (ivalue.isGenericDict()) {
    auto dict = std::move(ivalue).toGenericDict();
    py::dict py_dict;
    for (auto& pair : dict) {
      py_dict[toPyObject(IValue{pair.key()})] =
          toPyObject(IValue{pair.value()});
    }
    return std::move(py_dict);
  } else if (ivalue.isRRef()) {
#ifdef USE_RPC
    auto RRefPtr =
        c10::dynamic_intrusive_pointer_cast<torch::distributed::rpc::RRef>(
            std::move(ivalue).toRRef());
    return py::cast(torch::distributed::rpc::PyRRef(RRefPtr));
#else
    AT_ERROR("RRef is only supported with the distributed package");
#endif
  } else if (ivalue.isObject()) {
    const auto obj = std::move(ivalue).toObject();
    if (obj->type()->is_module()) {
      return py::cast(Module(obj));
    }

    auto pyCu = get_python_cu();
    if (obj->name().find("__torch__.torch.classes") == 0) {
      return py::cast(Object(obj));
    }
    const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
    AT_ASSERT(classType, c10::str(obj->name(), " is not found."));
    auto pyClass = getScriptedClassOrError(obj->type());
    auto pyObj = pyClass.attr("__new__")(pyClass);

    const auto numAttrs = classType->numAttributes();

    for (const auto slot : c10::irange(numAttrs)) {
      const auto& attrName = classType->getAttributeName(slot);
      IValue v = obj->getSlot(slot);
      py::setattr(pyObj, attrName.c_str(), toPyObject(std::move(v)));
    }
    return pyObj;
  } else if (ivalue.isPyObject()) {
    // Return a borrowed reference to ensure we correctly incref the
    // underlying PyObject.
    return py::reinterpret_borrow<py::object>(ivalue.toPyObject());
  } else if (ivalue.isCapsule()) {
    return py::cast(c10::Capsule(ivalue.toCapsule()));
  } else if (ivalue.isFuture()) {
    return py::cast(std::make_shared<PythonFutureWrapper>(ivalue.toFuture()));
  } else if (ivalue.isAwait()) {
    return py::cast(std::make_shared<PythonAwaitWrapper>(ivalue.toAwait()));
  } else if (ivalue.isEnum()) {
    auto enum_holder = ivalue.toEnumHolder();
    auto py_class = getScriptedClassOrError(enum_holder->type());
    return py_class.attr(enum_holder->name().c_str());
  } else if (ivalue.isSymInt()) {
    return py::cast(std::move(ivalue).toSymInt());
  } else if (ivalue.isSymFloat()) {
    return py::cast(std::move(ivalue).toSymFloat());
  } else if (ivalue.isSymBool()) {
    return py::cast(std::move(ivalue).toSymBool());
  } else {
    AT_ERROR(
        "Missing cases in 'toPyObject'! Can't convert ",
        ivalue.tagKind(),
        " to a Python object");
  }
}

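// Resolve which of `operations` matches the given Python args/kwargs and
// build the interpreter stack for it. With a single candidate, its schema is
// applied directly; with several, each is tried in order and all
// schema_match_error messages are aggregated into the thrown error.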
std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
    const std::vector<std::shared_ptr<Operator>>& operations,
    const py::args& args,
    const py::kwargs& kwargs) {
  Stack stack;
  if (operations.size() == 1) {
    std::shared_ptr<Operator> op = operations.at(0);
    // Create a stack full of the arguments and keyword arguments.
    stack = createStackForSchema(op->schema(), args, kwargs, std::nullopt);

    return std::make_pair(std::move(op), std::move(stack));
  } else {
    std::vector<schema_match_error> errors;
    std::shared_ptr<Operator> found_op = nullptr;
    for (const auto& op : operations) {
      try {
        stack = createStackForSchema(op->schema(), args, kwargs, std::nullopt);
        found_op = op;
        break;
      } catch (schema_match_error& error) {
        errors.push_back(std::move(error));
      }
    }
    if (!found_op) {
      std::stringstream ss;
      ss << "Overloaded torch operator invoked from Python failed to match any schema:\n";
      for (const auto& err : errors) {
        ss << err.what() << "\n\n";
      }
      throw std::runtime_error(ss.str());
    }

    return std::make_pair(std::move(found_op), std::move(stack));
  }
}

// This function is used to check if the schema is valid for the given args
// and kwargs. It validates script object arguments by checking whether the
// FakeScriptObject is an instance of the corresponding fake class for the
// actual class used in the schema.
bool checkSchemaAllowFakeScriptObject(
    const FunctionSchema& schema,
    const py::args& args,
    const py::kwargs& kwargs) {
  bool match = false;
  try {
    match = matchSchemaAllowFakeScriptObject(schema, args, kwargs);
  } catch (schema_match_error& error) {
    throw std::runtime_error(error.what());
  }
  return match;
}

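// Match one of `operations` against args/kwargs, run the operation with the
// GIL released around the actual call, and wrap the resulting stack back
// into a Python object. If `dk` is set, the operation is dispatched through
// that specific dispatch key.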
py::object invokeOperatorFromPython(
    const std::vector<std::shared_ptr<Operator>>& operations,
    const py::args& args,
    const py::kwargs& kwargs,
    std::optional<c10::DispatchKey> dk) {
  auto [found_op, stack] = getOpWithStack(operations, args, kwargs);
  {
    pybind11::gil_scoped_release no_gil_guard;
    if (dk) {
      found_op->getOperationForDispatchKey(*dk)(stack);
    } else {
      found_op->getOperation()(stack);
    }
  }

  return createPyObjectForStack(std::move(stack));
}

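// If any argument carries a __torch_function__ override (or a torch function
// mode is active), dispatch to that handler and return its result; otherwise
// return std::nullopt so the caller falls through to normal invocation.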
std::optional<py::object> _maybe_handle_torch_function(
    const std::string& ns,
    const std::string& method_name,
    const std::string& overload_name,
    bool is_overload,
    const py::args& args,
    const py::kwargs& kwargs) {
  std::vector<PyObject*> overloaded_args;
  size_t total_arg_num = args.size() + kwargs.size();
  for (const auto i : c10::irange(args.size())) {
    is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args);
    is_tensor_list_and_append_overloaded(
        args[i].ptr(),
        &overloaded_args,
        static_cast<int>(total_arg_num),
        false /* throw_error */);
  }
  // NB: for kwargs, we cannot guarantee the order of appending
  // is the same as the argument order in operator's schema.
  // This is suboptimal, but should be fine. Later when we have
  // better schema matching and argument parsing, we could
  // match the operator in `operations` first, then the order will
  // be guaranteed.
  for (auto item : kwargs) {
    is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
    is_tensor_list_and_append_overloaded(
        item.second.ptr(),
        &overloaded_args,
        static_cast<int>(total_arg_num),
        false /* throw_error */);
  }
  if (!overloaded_args.empty() || at::impl::torch_function_mode_enabled()) {
    auto self_func = py::module::import("torch")
                         .attr("ops")
                         .attr(ns.c_str())
                         .attr(method_name.c_str());
    if (is_overload) {
      if (overload_name.empty()) {
        self_func = self_func.attr("default");
      } else {
        self_func = self_func.attr(overload_name.c_str());
      }
    }
    std::string module_name("torch.ops.");
    module_name.append(ns);
    return {pybind11::reinterpret_steal<py::object>(
        handle_torch_function_no_python_arg_parser(
            overloaded_args,
            args.ptr(),
            kwargs.ptr(),
            method_name.c_str(),
            self_func.ptr(),
            module_name.c_str()))};
  }
  return std::nullopt;
}

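// Shared entry point for calling an OpOverload or OpOverloadPacket from
// Python: give __torch_function__ a chance to intercept first, then fall
// back to invoking the matched operator directly.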
py::object _get_operation_for_overload_or_packet(
    const std::vector<std::shared_ptr<Operator>>& operations,
    Symbol symbol,
    const py::args& args,
    const py::kwargs& kwargs,
    bool is_overload,
    std::optional<c10::DispatchKey> dk) {
  std::string ns = symbol.ns().toUnqualString();
  std::string method_name = symbol.toUnqualString();
  std::string overload_name = operations[0]->schema().overload_name();
  auto res = _maybe_handle_torch_function(
      ns, method_name, overload_name, is_overload, args, kwargs);
  auto torch_function_called = res.has_value();
  return torch_function_called
      ? *res
      : invokeOperatorFromPython(operations, args, kwargs, dk);
}

} // namespace torch::jit