1 #include <torch/csrc/jit/ir/graph_utils.h>
2 #include <torch/csrc/jit/python/module_python.h>
3 #include <torch/csrc/jit/python/pybind_utils.h>
4 #include <torch/csrc/jit/python/python_dict.h>
5 #include <torch/csrc/jit/python/python_ivalue.h>
6 #include <torch/csrc/jit/python/python_list.h>
7 #include <torch/csrc/jit/python/utf8_decoding_ignore.h>
8
9 #include <ATen/ScalarOps.h>
10
11 #include <c10/core/QScheme.h>
12 #include <c10/util/irange.h>
13 #include <torch/csrc/utils/python_arg_parser.h>
14
15 #include <limits>
16 #include <optional>
17 #include <utility>
18
19 namespace torch::jit {
20
21 static thread_local bool allow_numbers_as_tensors = false;
22
ToIValueAllowNumbersAsTensors(bool enable)23 ToIValueAllowNumbersAsTensors::ToIValueAllowNumbersAsTensors(bool enable)
24 : old_(allow_numbers_as_tensors) {
25 allow_numbers_as_tensors = enable;
26 }
27
~ToIValueAllowNumbersAsTensors()28 ToIValueAllowNumbersAsTensors::~ToIValueAllowNumbersAsTensors() {
29 allow_numbers_as_tensors = old_;
30 }
31
32 // This is a hack to remove instances deleted in C++ from the PyBind cache
33 // C++->Python. We need this because otherwise we may get the old Python object
34 // if C++ creates a new object at the memory location of the deleted object.
clear_registered_instances(void * ptr)35 void clear_registered_instances(void* ptr) {
36 auto& registered_instances =
37 pybind11::detail::get_internals().registered_instances;
38 auto range = registered_instances.equal_range(ptr);
39 for (auto it = range.first; it != range.second; ++it) {
40 auto vh = it->second->get_value_and_holder();
41 vh.set_instance_registered(false);
42 }
43 registered_instances.erase(ptr);
44 }
45
46 // WARNING: Precondition for this function is that, e.g., you have tested if a
47 // SymIntList is in fact only ints, and if so, you called this with T=int64_t.
48 // This precondition is NOT checked at runtime.
49 template <typename T>
listToIValue(py::handle obj)50 IValue listToIValue(py::handle obj) {
51 c10::List<T> rs;
52 for (auto it = obj.begin(); it != obj.end(); it++) {
53 auto elm = *it;
54 rs.push_back(py::cast<T>(elm));
55 }
56 // Promises that we have decayed the list appropriately
57 return c10::impl::toList<T>(rs);
58 }
59
toIValue(py::handle obj,const TypePtr & type,std::optional<int32_t> N)60 IValue toIValue(py::handle obj, const TypePtr& type, std::optional<int32_t> N) {
61 switch (type->kind()) {
62 case TypeKind::TensorType: {
63 if (obj.ptr() == Py_None) {
64 // None gets converted to undefined Tensors
65 return autograd::Variable();
66 }
67 if (THPVariable_Check(obj.ptr())) {
68 auto var = py::cast<autograd::Variable>(obj);
69 guardAgainstNamedTensor<autograd::Variable>(var);
70 return var;
71 } else {
72 if (!allow_numbers_as_tensors) {
73 throw py::cast_error(
74 c10::str("Unable to cast ", py::str(obj), " to Tensor"));
75 }
76 bool save_symint = false;
77 at::Scalar scalar;
78 if (PyBool_Check(obj.ptr())) {
79 scalar = at::Scalar(THPUtils_unpackBool(obj.ptr()));
80 } else if (THPUtils_checkLong(obj.ptr())) {
81 scalar = at::Scalar(THPUtils_unpackLong(obj.ptr()));
82 } else if (PyComplex_Check(obj.ptr())) {
83 scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr()));
84 } else if (THPUtils_checkDouble(obj.ptr())) {
85 scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr()));
86 } else if (torch::is_symint(py::handle(obj))) {
87 save_symint = true;
88 scalar = at::Scalar(7777777);
89 } else if (torch::is_symfloat(py::handle(obj))) {
90 save_symint = true;
91 scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
92 } else if (torch::is_symbool(py::handle(obj))) {
93 save_symint = true;
94 scalar = at::Scalar(true);
95 } else {
96 throw py::cast_error(
97 c10::str("Unable to cast ", py::str(obj), " to Tensor"));
98 }
99 at::Tensor tensor = at::scalar_to_tensor(scalar);
100 tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
101
102 if (save_symint) {
103 auto py_tensor = py::cast(tensor);
104 if (PyObject_SetAttrString(
105 py_tensor.ptr(), "_wrapped_number", obj.ptr()) < 0) {
106 throw python_error();
107 }
108 }
109
110 return tensor;
111 }
112 }
113 case TypeKind::StorageType:
114 return py::cast<at::Storage>(obj);
115 case TypeKind::FloatType:
116 if (torch::is_symfloat(py::handle(obj))) {
117 return py::cast<c10::SymFloat>(obj).guard_float(__FILE__, __LINE__);
118 }
119 if (THPVariable_Check(obj.ptr())) {
120 auto var = py::cast<autograd::Variable>(obj);
121 // NB: We carefully test if the storage is meta, because that is
122 // always accurate even if you have a fake tensor (which is the
123 // primary case we are trying to detect here)
124 if (var.storage().device_type() == c10::kMeta) {
125 throw py::cast_error(
126 "cannot extract float from tensor with meta storage");
127 }
128 }
129 return py::cast<double>(obj);
130 case TypeKind::ComplexType: {
131 auto c_obj = py::cast<std::complex<double>>(obj.ptr());
132 return static_cast<c10::complex<double>>(c_obj);
133 }
134 case TypeKind::IntType:
135 // TODO: Properly fake this type
136 if (THPQScheme_Check(obj.ptr())) {
137 auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
138 return static_cast<uint8_t>(qscheme->qscheme);
139 }
140 // For backwards compatibility
141 if (THPDtype_Check(obj.ptr())) {
142 auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
143 return static_cast<int64_t>(dtype->scalar_type);
144 }
145 if (THPQScheme_Check(obj.ptr())) {
146 auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
147 return static_cast<uint8_t>(qscheme->qscheme);
148 }
149 if (THPLayout_Check(obj.ptr())) {
150 auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
151 return static_cast<int8_t>(layout->layout);
152 }
153 if (THPMemoryFormat_Check(obj.ptr())) {
154 auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
155 return static_cast<int8_t>(memory_format->memory_format);
156 }
157 if (torch::is_symint(py::handle(obj))) {
158 return py::cast<c10::SymInt>(obj).guard_int(__FILE__, __LINE__);
159 }
160 if (THPVariable_Check(obj.ptr())) {
161 auto var = py::cast<autograd::Variable>(obj);
162 if (var.storage().device_type() == c10::kMeta) {
163 throw py::cast_error(
164 "cannot extract int from tensor with meta storage");
165 }
166 }
167 return py::cast<int64_t>(obj);
168 case TypeKind::LayoutType: {
169 if (THPLayout_Check(obj.ptr())) {
170 auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
171 return static_cast<int8_t>(layout->layout);
172 }
173 // For backwards compatibility
174 return py::cast<int64_t>(obj);
175 }
176 case TypeKind::ScalarTypeType: {
177 if (THPDtype_Check(obj.ptr())) {
178 auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
179 return static_cast<int64_t>(dtype->scalar_type);
180 }
181 // For backwards compatibility
182 return py::cast<int64_t>(obj);
183 }
184 case TypeKind::MemoryFormatType: {
185 if (THPMemoryFormat_Check(obj.ptr())) {
186 auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
187 return static_cast<int8_t>(memory_format->memory_format);
188 }
189 // For backwards compatibility
190 return py::cast<int64_t>(obj);
191 }
192 case TypeKind::SymIntType:
193 if (torch::is_symint(obj.ptr())) {
194 return py::cast<c10::SymInt>(obj);
195 }
196 return py::cast<int64_t>(obj);
197 case TypeKind::SymFloatType:
198 if (torch::is_symfloat(obj.ptr())) {
199 return py::cast<c10::SymFloat>(obj);
200 }
201 return py::cast<double>(obj);
202 case TypeKind::SymBoolType:
203 if (torch::is_symbool(obj.ptr())) {
204 return py::cast<c10::SymBool>(obj);
205 }
206 return py::cast<bool>(obj);
207 case TypeKind::NoneType:
208 if (!obj.is_none()) {
209 throw py::cast_error(
210 c10::str("Cannot cast ", py::str(obj), " to None"));
211 }
212 return {};
213 case TypeKind::BoolType:
214 if (torch::is_symbool(obj.ptr())) {
215 return py::cast<c10::SymBool>(obj).guard_bool(__FILE__, __LINE__);
216 }
217 if (THPVariable_Check(obj.ptr())) {
218 auto var = py::cast<autograd::Variable>(obj);
219 if (var.storage().device_type() == c10::kMeta) {
220 throw py::cast_error(
221 "cannot extract bool from tensor with meta storage");
222 }
223 }
224 return py::cast<bool>(obj);
225 case TypeKind::TupleType: {
226 py::tuple tuple = py::cast<py::tuple>(obj);
227 size_t tuple_size = tuple.size();
228 auto tuple_type = type->cast<TupleType>();
229 const auto& elem_types = tuple_type->elements();
230 if (elem_types.size() != tuple_size) {
231 throw py::cast_error(c10::str(
232 "Object ",
233 py::str(obj),
234 " had a different number of elements than type ",
235 type->repr_str()));
236 }
237 std::vector<IValue> values;
238 values.reserve(tuple_size);
239 for (const auto i : c10::irange(tuple_size)) {
240 values.push_back(toIValue(tuple[i], elem_types[i]));
241 }
242 return tuple_type->name()
243 ? c10::ivalue::Tuple::createNamed(std::move(values), tuple_type)
244 : c10::ivalue::Tuple::create(std::move(values));
245 }
246 case TypeKind::UnionType: {
247 auto actual_type = toTypeInferredIValue(obj);
248 auto actual_type_ptr = actual_type.type();
249 auto union_type = type->expect<UnionType>();
250 if (!actual_type_ptr->isSubtypeOf(union_type)) {
251 throw py::cast_error(c10::str(
252 "Expected a member of ",
253 union_type->annotation_str(),
254 " but instead found type ",
255 actual_type.type()->annotation_str()));
256 }
257 return actual_type;
258 }
259 case TypeKind::StringType:
260 return ConstantString::create(py::cast<std::string>(obj));
261 case TypeKind::DeviceObjType: {
262 if (THPDevice_Check(obj.ptr())) {
263 auto device = reinterpret_cast<THPDevice*>(obj.ptr());
264 return device->device;
265 }
266 return c10::Device(py::cast<std::string>(obj.ptr()));
267 }
268 case TypeKind::StreamObjType: {
269 auto thp_stream = reinterpret_cast<THPStream*>(obj.ptr());
270 auto stream = c10::Stream::unpack3(
271 thp_stream->stream_id,
272 static_cast<c10::DeviceIndex>(thp_stream->device_index),
273 static_cast<c10::DeviceType>(thp_stream->device_type));
274 return stream;
275 }
276 case TypeKind::ListType: {
277 // If the object is a ScriptList, retrieve the c10::List
278 // instance inside it.
279 if (py::isinstance<ScriptList>(obj)) {
280 return py::cast<ScriptList>(obj).list_;
281 }
282
283 // If not (i.e. it is a regular Python list), make a new
284 // c10::List.
285 const auto& elem_type = type->expectRef<ListType>().getElementType();
286 switch (elem_type->kind()) {
287 // allows single int/float to be broadcasted to a fixed size list
288 case TypeKind::IntType:
289 if (!N || !py::isinstance<py::int_>(obj)) {
290 return IValue(py::cast<std::vector<int64_t>>(obj));
291 } else {
292 int64_t value = py::cast<int64_t>(obj);
293 c10::List<int64_t> repeated;
294 repeated.reserve(*N);
295 for (int i = 0; i < *N; ++i) {
296 repeated.push_back(value);
297 }
298 return repeated;
299 }
300 case TypeKind::SymIntType: {
301 bool is_symbolic = false;
302 for (auto it = obj.begin(); it != obj.end(); it++) {
303 auto elm = *it;
304 if (torch::is_symint(elm)) {
305 is_symbolic = true;
306 break;
307 }
308 }
309 if (is_symbolic) {
310 return listToIValue<c10::SymInt>(obj);
311 } else {
312 return listToIValue<int64_t>(obj);
313 }
314 }
315 case TypeKind::SymFloatType: {
316 bool is_symbolic = false;
317 for (auto it = obj.begin(); it != obj.end(); it++) {
318 auto elm = *it;
319 // TODO: what about SymInt conversion to SymFloat?
320 if (torch::is_symfloat(elm)) {
321 is_symbolic = true;
322 break;
323 }
324 }
325 if (is_symbolic) {
326 return listToIValue<c10::SymFloat>(obj);
327 } else {
328 return listToIValue<double>(obj);
329 }
330 }
331 case TypeKind::SymBoolType: {
332 bool is_symbolic = false;
333 for (auto it = obj.begin(); it != obj.end(); it++) {
334 auto elm = *it;
335 if (torch::is_symbool(elm)) {
336 is_symbolic = true;
337 break;
338 }
339 }
340 if (is_symbolic) {
341 return listToIValue<c10::SymBool>(obj);
342 } else {
343 return listToIValue<bool>(obj);
344 }
345 }
346 case TypeKind::FloatType:
347 if (!N || !py::isinstance<py::float_>(obj)) {
348 return IValue(py::cast<std::vector<double>>(obj));
349 } else {
350 double value = py::cast<double>(obj);
351 c10::List<double> repeated;
352 repeated.reserve(*N);
353 for (int i = 0; i < *N; ++i) {
354 repeated.push_back(value);
355 }
356 return repeated;
357 }
358 case TypeKind::BoolType:
359 return IValue(py::cast<std::vector<bool>>(obj));
360 case TypeKind::TensorType:
361 return IValue(py::cast<std::vector<at::Tensor>>(obj));
362 default:
363 return createGenericList(obj, elem_type);
364 }
365 }
366 case TypeKind::DictType: {
367 const auto& dict_type = type->expect<DictType>();
368
369 // If the object is a ScriptDict, retrieve the c10::Dict
370 // instance inside it.
371 try {
372 auto script_dict = py::cast<ScriptDict>(obj);
373 return script_dict.dict_;
374 } catch (py::cast_error& e) {
375 }
376
377 // If not (i.e. it is a regular Python dictionary), make a new
378 // c10::Dict.
379 return createGenericDict(
380 py::cast<py::dict>(obj),
381 dict_type->getKeyType(),
382 dict_type->getValueType());
383 }
384 case TypeKind::OptionalType: {
385 // check if it's a none obj since optional accepts NoneType
386 if (obj.is_none()) {
387 // check if it's a none obj since optional accepts NoneType
388 // return an IValue() to denote a NoneType
389 return {};
390 }
391 return toIValue(obj, type->expectRef<OptionalType>().getElementType(), N);
392 }
393 case TypeKind::ClassType: {
394 auto classType = type->expect<ClassType>();
395 auto object = py::cast<py::object>(obj);
396 if (auto mod = as_module(object)) {
397 // if obj is already a ScriptModule, just return its ivalue
398 return mod.value()._ivalue();
399 }
400
401 // Check if the obj is a ScriptObject.
402 if (auto script_obj = as_object(object)) {
403 return script_obj.value()._ivalue();
404 }
405
406 // otherwise is a normal class object, we create a fresh
407 // ivalue::Object to use from the py object.
408 // 1. create a bare ivalue
409 const size_t numAttrs = classType->numAttributes();
410 auto cu = classType->compilation_unit();
411 auto userObj = c10::ivalue::Object::create(
412 c10::StrongTypePtr(cu, classType), numAttrs);
413
414 // 2. copy all the contained types
415 for (const auto slot : c10::irange(numAttrs)) {
416 const auto& attrType = classType->getAttribute(slot);
417 const auto& attrName = classType->getAttributeName(slot);
418
419 if (!py::hasattr(obj, attrName.c_str())) {
420 throw py::cast_error(c10::str(
421 "Tried to cast object to type ",
422 type->repr_str(),
423 " but object",
424 " was missing attribute ",
425 attrName));
426 }
427
428 try {
429 const auto& contained = py::getattr(obj, attrName.c_str());
430 userObj->setSlot(slot, toIValue(contained, attrType));
431 } catch (std::exception& e) {
432 throw py::cast_error(c10::str(
433 "Could not cast attribute '",
434 attrName,
435 "' to type ",
436 attrType->repr_str(),
437 ": ",
438 e.what()));
439 }
440 }
441 return userObj;
442 }
443 case TypeKind::InterfaceType: {
444 auto interfaceType = type->expect<InterfaceType>();
445 // When converting an pyobj to an interface, we check if rhs
446 // is module or normal torchscript class, get the type and ivalue
447 // from them correspondingly.
448 c10::ClassTypePtr classType = nullptr;
449 IValue res;
450 if (auto mod = as_module(py::cast<py::object>(obj))) {
451 classType = mod.value().type();
452 res = mod.value()._ivalue();
453 } else if (auto object = as_object(py::cast<py::object>(obj))) {
454 classType = object.value().type();
455 res = object.value()._ivalue();
456 } else {
457 // We inspect the value to found the compiled TorchScript class
458 // and then create a ivalue::Object from that class type.
459 py::str qualified_name = py::module::import("torch._jit_internal")
460 .attr("_qualified_name")(obj.get_type());
461 auto pyCu = get_python_cu();
462 classType = pyCu->get_class(c10::QualifiedName(qualified_name));
463 if (!classType) {
464 throw std::runtime_error(c10::str(
465 "Assigning the object ",
466 py::str(obj),
467 " to an interface fails because the value is not "
468 "a TorchScript compatible type, did you forget to",
469 "turn it into a user defined TorchScript class?"));
470 }
471 res = toIValue(obj, classType);
472 }
473 // check if the classType conform with the interface or not
474 std::stringstream why_not;
475 if (!classType->isSubtypeOfExt(*interfaceType, &why_not)) {
476 throw py::cast_error(c10::str(
477 "Object of type ",
478 classType->repr_str(),
479 " is not compatible with interface ",
480 interfaceType->repr_str(),
481 "\n",
482 why_not.str()));
483 }
484 return res;
485 }
486 case TypeKind::NumberType: {
487 if (THPDtype_Check(obj.ptr())) {
488 auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
489 return static_cast<int64_t>(dtype->scalar_type);
490 }
491 if (THPQScheme_Check(obj.ptr())) {
492 auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
493 return static_cast<uint8_t>(qscheme->qscheme);
494 }
495 if (THPLayout_Check(obj.ptr())) {
496 auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
497 return static_cast<int8_t>(layout->layout);
498 }
499 if (py::isinstance<py::bool_>(obj)) {
500 return py::cast<bool>(obj);
501 } else if (py::isinstance<py::int_>(obj)) {
502 return py::cast<int64_t>(obj);
503 } else if (py::isinstance<py::float_>(obj)) {
504 return py::cast<double>(obj);
505 } else if (PyComplex_CheckExact(obj.ptr())) {
506 auto c_obj = py::cast<std::complex<double>>(obj.ptr());
507 return static_cast<c10::complex<double>>(c_obj);
508 } else if (torch::is_symint(obj)) {
509 return py::cast<c10::SymInt>(obj);
510 } else if (torch::is_symfloat(obj)) {
511 return py::cast<c10::SymFloat>(obj);
512 } else if (torch::is_symbool(obj)) {
513 return py::cast<c10::SymBool>(obj);
514 } else {
515 throw py::cast_error(
516 c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
517 }
518 }
519 case TypeKind::RRefType: {
520 #ifdef USE_RPC
521 return obj.cast<torch::distributed::rpc::PyRRef>().toIValue();
522 #else
523 AT_ERROR("RRef is only supported with the distributed package");
524 #endif
525 } break;
526 case TypeKind::PyObjectType: {
527 return c10::ivalue::ConcretePyObjectHolder::create(obj);
528 }
529 case TypeKind::CapsuleType: {
530 return IValue::make_capsule(py::cast<c10::Capsule>(obj).obj_ptr);
531 }
532 case TypeKind::FutureType: {
533 return obj.cast<std::shared_ptr<PythonFutureWrapper>>()->fut;
534 }
535 case TypeKind::AwaitType: {
536 return obj.cast<std::shared_ptr<PythonAwaitWrapper>>()->aw_;
537 }
538 case TypeKind::AnyType:
539 return toTypeInferredIValue(obj);
540 case TypeKind::QSchemeType: {
541 if (py::isinstance<py::int_>(obj)) {
542 return static_cast<at::QScheme>(py::cast<int64_t>(obj));
543 }
544 throw py::cast_error(
545 c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
546 }
547 case TypeKind::GeneratorType:
548 return py::cast<at::Generator>(obj);
549 case TypeKind::DynamicType:
550 case TypeKind::FunctionType:
551 case TypeKind::QuantizerType:
552 case TypeKind::VarType:
553 case TypeKind::AnyListType:
554 case TypeKind::AnyTupleType:
555 case TypeKind::AnyClassType:
556 case TypeKind::AnyEnumType:
557 break;
558 case TypeKind::EnumType:
559 EnumTypePtr enum_type = type->expect<EnumType>();
560 py::object py_obj = py::reinterpret_borrow<py::object>(obj);
561 std::string name = py::cast<std::string>(obj.attr("name"));
562 IValue value = toIValue(obj.attr("value"), enum_type->getValueType(), {});
563 auto enum_holder =
564 c10::make_intrusive<c10::ivalue::EnumHolder>(enum_type, name, value);
565 return IValue(enum_holder);
566 }
567 throw py::cast_error(c10::str(
568 "toIValue() cannot handle converting to type: ", type->repr_str()));
569 }
570
toPyObject(IValue ivalue)571 py::object toPyObject(IValue ivalue) {
572 if (ivalue.isNone()) {
573 return py::none();
574 } else if (ivalue.isTensor()) {
575 auto tensor = std::move(ivalue).toTensor();
576 if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
577 TORCH_INTERNAL_ASSERT(tensor.device().is_cpu());
578 auto py_tensor = py::cast(tensor);
579 if (PyObject_HasAttrString(py_tensor.ptr(), "_wrapped_number")) {
580 return py_tensor.attr("_wrapped_number");
581 }
582 auto scalar_type = tensor.scalar_type();
583 switch (scalar_type) {
584 case at::ScalarType::Bool:
585 return py::cast(*tensor.const_data_ptr<bool>());
586 case at::ScalarType::Long:
587 return py::cast(*tensor.const_data_ptr<int64_t>());
588 case at::ScalarType::Double:
589 return py::cast(*tensor.const_data_ptr<double>());
590 case at::ScalarType::ComplexDouble:
591 // TODO: https://github.com/pytorch/pytorch/issues/77134
592 return py::cast(static_cast<std::complex<double>>(
593 *tensor.const_data_ptr<c10::complex<double>>()));
594 default:
595 TORCH_CHECK(
596 false,
597 "Missing cases in 'toPyObject' wrapped number handling! Can't convert ",
598 scalar_type,
599 " to a Python object");
600 }
601 } else {
602 guardAgainstNamedTensor<at::Tensor>(tensor);
603 return py::cast(autograd::Variable(std::move(tensor)));
604 }
605 } else if (ivalue.isStorage()) {
606 return py::cast(std::move(ivalue).toStorage());
607 } else if (ivalue.isGenerator()) {
608 return py::cast(std::move(ivalue).toGenerator());
609 } else if (ivalue.isDouble()) {
610 return py::cast(std::move(ivalue).toDouble());
611 } else if (ivalue.isComplexDouble()) {
612 return py::cast(
613 static_cast<std::complex<double>>(std::move(ivalue).toComplexDouble()));
614 } else if (ivalue.isInt()) {
615 return py::cast(std::move(ivalue).toInt());
616 } else if (ivalue.isBool()) {
617 return py::cast(std::move(ivalue).toBool());
618 } else if (ivalue.isString()) {
619 if (getUTF8DecodingIgnore()) {
620 std::string s = std::move(ivalue).toStringRef();
621 PyObject* pyObj = PyUnicode_DecodeUTF8(s.data(), s.length(), "ignore");
622 return py::reinterpret_steal<py::object>(pyObj);
623 } else {
624 return py::cast(std::move(ivalue).toStringRef());
625 }
626 } else if (ivalue.isList()) {
627 auto list = std::move(ivalue).toList();
628 py::list t{list.size()};
629 for (const auto i : c10::irange(list.size())) {
630 t[i] = toPyObject(IValue{list.get(i)});
631 }
632 return std::move(t);
633 } else if (ivalue.isTuple()) {
634 auto tuple = std::move(ivalue).toTuple();
635 const auto& elements = tuple->elements();
636
637 py::tuple t{elements.size()};
638 for (const auto i : c10::irange(elements.size())) {
639 t[i] = toPyObject(IValue{elements.at(i)});
640 }
641
642 // If we have a NamedTuple
643 if (tuple->type() && tuple->type()->schema() &&
644 !tuple->type()->schema()->name().empty()) {
645 auto unqualName = tuple->type()->name()->name();
646
647 std::vector<Argument> tuple_args = tuple->type()->schema()->arguments();
648
649 std::vector<pybind11::object> defaults;
650 auto it = std::find_if(
651 tuple_args.begin(), tuple_args.end(), [](const Argument& arg) {
652 return arg.default_value().has_value();
653 });
654 std::transform(
655 it,
656 tuple_args.end(),
657 std::back_inserter(defaults),
658 [](const Argument& arg) { return toPyObject(*arg.default_value()); });
659
660 std::vector<std::string> fieldNames =
661 fmap(tuple_args, [](const Argument& arg) { return arg.name(); });
662
663 return py::module::import("torch._jit_internal")
664 .attr("_create_named_tuple")(
665 t, unqualName, fieldNames, py::make_tuple(defaults));
666 } else {
667 return std::move(t);
668 }
669 } else if (ivalue.isDevice()) {
670 return py::cast(std::move(ivalue).toDevice());
671 } else if (ivalue.isStream()) {
672 return py::cast(std::move(ivalue).toStream());
673 } else if (ivalue.isGenericDict()) {
674 auto dict = std::move(ivalue).toGenericDict();
675 py::dict py_dict;
676 for (auto& pair : dict) {
677 py_dict[toPyObject(IValue{pair.key()})] =
678 toPyObject(IValue{pair.value()});
679 }
680 return std::move(py_dict);
681 } else if (ivalue.isRRef()) {
682 #ifdef USE_RPC
683 auto RRefPtr =
684 c10::dynamic_intrusive_pointer_cast<torch::distributed::rpc::RRef>(
685 std::move(ivalue).toRRef());
686 return py::cast(torch::distributed::rpc::PyRRef(RRefPtr));
687 #else
688 AT_ERROR("RRef is only supported with the distributed package");
689 #endif
690 } else if (ivalue.isObject()) {
691 const auto obj = std::move(ivalue).toObject();
692 if (obj->type()->is_module()) {
693 return py::cast(Module(obj));
694 }
695
696 auto pyCu = get_python_cu();
697 if (obj->name().find("__torch__.torch.classes") == 0) {
698 return py::cast(Object(obj));
699 }
700 const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
701 AT_ASSERT(classType, c10::str(obj->name(), " is not found."));
702 auto pyClass = getScriptedClassOrError(obj->type());
703 auto pyObj = pyClass.attr("__new__")(pyClass);
704
705 const auto numAttrs = classType->numAttributes();
706
707 for (const auto slot : c10::irange(numAttrs)) {
708 const auto& attrName = classType->getAttributeName(slot);
709 IValue v = obj->getSlot(slot);
710 py::setattr(pyObj, attrName.c_str(), toPyObject(std::move(v)));
711 }
712 return pyObj;
713 } else if (ivalue.isPyObject()) {
714 // return borrowed reference to ensure it correctly incref the underlying
715 // PyObject
716 return py::reinterpret_borrow<py::object>(ivalue.toPyObject());
717 } else if (ivalue.isCapsule()) {
718 return py::cast(c10::Capsule(ivalue.toCapsule()));
719 } else if (ivalue.isFuture()) {
720 return py::cast(std::make_shared<PythonFutureWrapper>(ivalue.toFuture()));
721 } else if (ivalue.isAwait()) {
722 return py::cast(std::make_shared<PythonAwaitWrapper>(ivalue.toAwait()));
723 } else if (ivalue.isEnum()) {
724 auto enum_holder = ivalue.toEnumHolder();
725 auto py_class = getScriptedClassOrError(enum_holder->type());
726 return py_class.attr(enum_holder->name().c_str());
727 } else if (ivalue.isRRef()) {
728 #ifdef USE_RPC
729 return py::cast(torch::distributed::rpc::PyRRef(
730 c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(
731 ivalue.toRRef())));
732 #else
733 TORCH_CHECK(false, "RRef is only supported with the distributed package");
734 #endif
735 } else if (ivalue.isSymInt()) {
736 return py::cast(std::move(ivalue).toSymInt());
737 } else if (ivalue.isSymFloat()) {
738 return py::cast(std::move(ivalue).toSymFloat());
739 } else if (ivalue.isSymBool()) {
740 return py::cast(std::move(ivalue).toSymBool());
741 } else {
742 AT_ERROR(
743 "Missing cases in 'toPyObject'! Can't convert ",
744 ivalue.tagKind(),
745 " to a Python object");
746 }
747 }
748
getOpWithStack(const std::vector<std::shared_ptr<Operator>> & operations,const py::args & args,const py::kwargs & kwargs)749 std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
750 const std::vector<std::shared_ptr<Operator>>& operations,
751 const py::args& args,
752 const py::kwargs& kwargs) {
753 Stack stack;
754 if (operations.size() == 1) {
755 std::shared_ptr<Operator> op = operations.at(0);
756 // Create a stack full of the arguments and keyword arguments.
757 stack = createStackForSchema(op->schema(), args, kwargs, std::nullopt);
758
759 return std::make_pair(std::move(op), std::move(stack));
760 } else {
761 std::vector<schema_match_error> errors;
762 std::shared_ptr<Operator> found_op = nullptr;
763 for (const auto& op : operations) {
764 try {
765 stack = createStackForSchema(op->schema(), args, kwargs, std::nullopt);
766 found_op = op;
767 break;
768 } catch (schema_match_error& error) {
769 errors.push_back(std::move(error));
770 }
771 }
772 if (!found_op) {
773 std::stringstream ss;
774 ss << "Overloaded torch operator invoked from Python failed to match any schema:\n";
775 for (const auto& err : errors) {
776 ss << err.what() << "\n\n";
777 }
778 throw std::runtime_error(ss.str());
779 }
780
781 return std::make_pair(std::move(found_op), std::move(stack));
782 }
783 }
784
785 // This function is used to check if the schema is valid for the given args and
786 // kwargs. It checks script object by checking wether the FakeScriptObject is
787 // an instance of the corresponding fake class for the actual class used in
788 // schema.
checkSchemaAllowFakeScriptObject(const FunctionSchema & schema,const py::args & args,const py::kwargs & kwargs)789 bool checkSchemaAllowFakeScriptObject(
790 const FunctionSchema& schema,
791 const py::args& args,
792 const py::kwargs& kwargs) {
793 bool match = false;
794 try {
795 match = matchSchemaAllowFakeScriptObject(schema, args, kwargs);
796 } catch (schema_match_error& error) {
797 throw std::runtime_error(error.what());
798 }
799 return match;
800 }
801
// Resolves the overload matching args/kwargs and runs it with the GIL
// released. If `dk` is given, the operator's kernel for that dispatch key is
// used. Returns the operator's outputs converted to a Python object.
py::object invokeOperatorFromPython(
    const std::vector<std::shared_ptr<Operator>>& operations,
    const py::args& args,
    const py::kwargs& kwargs,
    std::optional<c10::DispatchKey> dk) {
  auto resolved = getOpWithStack(operations, args, kwargs);
  const std::shared_ptr<Operator>& chosen_op = resolved.first;
  Stack& arg_stack = resolved.second;

  {
    // The kernel runs purely in C++; drop the GIL for its duration.
    pybind11::gil_scoped_release no_gil_guard;
    if (!dk) {
      chosen_op->getOperation()(arg_stack);
    } else {
      chosen_op->getOperationForDispatchKey(*dk)(arg_stack);
    }
  }

  return createPyObjectForStack(std::move(arg_stack));
}
819
_maybe_handle_torch_function(const std::string & ns,const std::string & method_name,const std::string & overload_name,bool is_overload,const py::args & args,const py::kwargs & kwargs)820 std::optional<py::object> _maybe_handle_torch_function(
821 const std::string& ns,
822 const std::string& method_name,
823 const std::string& overload_name,
824 bool is_overload,
825 const py::args& args,
826 const py::kwargs& kwargs) {
827 std::vector<PyObject*> overloaded_args;
828 size_t total_arg_num = args.size() + kwargs.size();
829 for (const auto i : c10::irange(args.size())) {
830 is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args);
831 is_tensor_list_and_append_overloaded(
832 args[i].ptr(),
833 &overloaded_args,
834 static_cast<int>(total_arg_num),
835 false /* throw_error */);
836 }
837 // NB: for kwargs, we cannot guarantee the order of appending
838 // is the same as the argument order in operator's schema.
839 // This is suboptimal, but should be fine. Later when we have
840 // better schema matching and argument parsing, we could
841 // match the operator in `operations` first, then the order will
842 // be guaranteed.
843 for (auto item : kwargs) {
844 is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
845 is_tensor_list_and_append_overloaded(
846 item.second.ptr(),
847 &overloaded_args,
848 total_arg_num,
849 false /* throw_error */);
850 }
851 if (!overloaded_args.empty() || at::impl::torch_function_mode_enabled()) {
852 auto self_func = py::module::import("torch")
853 .attr("ops")
854 .attr(ns.c_str())
855 .attr(method_name.c_str());
856 if (is_overload) {
857 if (overload_name.empty()) {
858 self_func = self_func.attr("default");
859 } else {
860 self_func = self_func.attr(overload_name.c_str());
861 }
862 }
863 std::string module_name("torch.ops");
864 module_name.append(ns);
865 return {pybind11::reinterpret_steal<py::object>(
866 handle_torch_function_no_python_arg_parser(
867 overloaded_args,
868 args.ptr(),
869 kwargs.ptr(),
870 method_name.c_str(),
871 self_func.ptr(),
872 module_name.c_str()))};
873 }
874 return std::nullopt;
875 }
876
// Entry point for calling a torch.ops operator (overload or overload packet)
// from Python. __torch_function__ handling gets the first chance to
// intercept; otherwise the matching overload is invoked directly.
py::object _get_operation_for_overload_or_packet(
    const std::vector<std::shared_ptr<Operator>>& operations,
    Symbol symbol,
    const py::args& args,
    const py::kwargs& kwargs,
    bool is_overload,
    std::optional<c10::DispatchKey> dk) {
  const std::string ns = symbol.ns().toUnqualString();
  const std::string method_name = symbol.toUnqualString();
  const std::string overload_name = operations[0]->schema().overload_name();

  if (auto intercepted = _maybe_handle_torch_function(
          ns, method_name, overload_name, is_overload, args, kwargs)) {
    return *intercepted;
  }
  return invokeOperatorFromPython(operations, args, kwargs, dk);
}
894
895 } // namespace torch::jit
896