#include <torch/csrc/jit/python/python_sugared_value.h>

#include <ATen/core/interned_strings.h>
#include <c10/core/ScalarType.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/utils/pybind.h>
#include <climits>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#include <Python.h>

namespace torch::jit {

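// Returns the name of the Python type of `h` (e.g. "function", "module").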
std::string typeString(py::handle h) {
  return py::str(h.get_type().attr("__name__"));
}

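// If `obj` wraps a scripted function (StrongFunctionPtr), return it;
// otherwise return std::nullopt.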
std::optional<StrongFunctionPtr> as_function(const py::object& obj) {
  if (py::isinstance<StrongFunctionPtr>(obj)) {
    return py::cast<StrongFunctionPtr>(obj);
  }
  return std::nullopt;
}

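// Build a FunctionSchema for the wrapped Python callable. Type annotations
// are extracted via torch.jit.annotations; if none are present, every
// argument defaults to Tensor and the return type is derived from the
// number of binders at the call site.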
FunctionSchema PythonValue::getSchema(
    const size_t n_args,
    const size_t n_binders,
    const SourceRange& loc) {
  auto annotations = py::module::import("torch.jit.annotations");
  const auto callable = moduleSelf_ ? py::getattr(self, "original_fn") : self;

  // Make sure the function is not a class instantiation (e.g. `Exception()`)
  annotations.attr("check_fn")(callable, loc);
  auto is_vararg = py::cast<bool>(annotations.attr("is_vararg")(callable));

  auto signature = annotations.attr("get_signature")(
      callable, rcb ? *rcb : py::none(), loc, bool(moduleSelf_));
  std::vector<Argument> args, rets;

  auto py_param_names = annotations.attr("get_param_names")(callable, n_args);
  auto param_names = py::cast<std::vector<std::string>>(py_param_names);
  auto names_it = param_names.begin();
  if (moduleSelf_) {
    if (param_names.empty()) {
      throw(
          ErrorReport(loc)
          << "Non-static method does not have a self argument");
    }

    // If there is a `self` parameter on the callable, skip it on the names list
    args.emplace_back(Argument(*names_it, moduleSelf_->type(), {}, {}, false));
    ++names_it;
  }
  if (signature.is_none()) {
    // No type signature was provided on the callable, so make a default
    // signature where each argument is typed as a Tensor
    for (; names_it != param_names.end(); ++names_it) {
      args.emplace_back(
          /*name=*/*names_it,
          /*type=*/TensorType::get(),
          /*N=*/std::nullopt,
          /*default_value=*/std::nullopt,
          /*kwarg_only=*/false);
    }

    // Use as many outputs as are requested to make the return type
    TypePtr ret_type = TensorType::get();
    if (n_binders == 0) {
      ret_type = NoneType::get();
    } else if (n_binders > 1) {
      std::vector<TypePtr> tuple_values(n_binders, ret_type);
      ret_type = TupleType::create(std::move(tuple_values));
    }
    rets.emplace_back(Argument("0", ret_type, {}, {}, false));
  } else {
    // Use the provided type signature
    auto [arg_types, ret_type] =
        py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);

    // arg_types does not include self but param_names does, so adjust for that
    // if needed
    TORCH_INTERNAL_ASSERT(
        arg_types.size() == param_names.size() - (moduleSelf_ ? 1 : 0));

    auto types_it = arg_types.begin();
    for (; types_it != arg_types.end(); ++types_it, ++names_it) {
      args.emplace_back(
          /*name=*/*names_it,
          /*type=*/std::move(*types_it),
          /*N=*/std::nullopt,
          /*default_value=*/std::nullopt,
          /*kwarg_only=*/false);
    }
    rets.push_back(Argument("0", ret_type, {}, {}, false));
  }

  std::string name;
  if (py::hasattr(self, "__qualname__")) {
    // Use the qualified name if possible
    name = py::str(py::getattr(self, "__qualname__"));
  } else if (py::hasattr(self, "__name__")) {
    name = py::str(py::getattr(self, "__name__"));
  }
  return FunctionSchema(name, "", std::move(args), std::move(rets), is_vararg);
}

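// Emit a call to the wrapped Python callable as a prim::PythonOp node, after
// matching the call against the schema produced by getSchema(). Callables
// marked as dropped instead raise an exception at runtime.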
std::shared_ptr<SugaredValue> PythonValue::call(
    const SourceRange& loc,
    GraphFunction& m,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    size_t n_binders) {
  std::vector<NamedValue> argsWithSelf;
  if (moduleSelf_) {
    argsWithSelf.emplace_back("self", moduleSelf_);
  }
  argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end());

  auto schema = getSchema(argsWithSelf.size(), n_binders, loc);
  auto inputs = toValues(*m.graph(), argsWithSelf);

  MatchedSchema matched_schema =
      matchSchema(schema, loc, *m.graph(), argsWithSelf, kwargs);

  // If a function is marked as dropped,
  // we throw an exception if it is invoked.
  if (py::cast<bool>(py::module::import("torch._jit_internal")
                         .attr("should_drop")(self))) {
    auto g = m.graph();
    auto err_msg = insertConstant(
        *g,
        IValue(
            "This Python function is annotated to be ignored and cannot be run"));
    g->insert(prim::RaiseException, {err_msg}, {}, loc);
    return std::make_shared<SimpleValue>(
        g->insertNode(g->createUninitialized(matched_schema.return_types.at(0)))
            ->output());
  }

  // Release the function object so we can wrap it in a PythonOp
  py::object func = self;
  std::string cconv(inputs.size(), 'd');
  Node* new_node = m.graph()->insertNode(
      m.graph()->createPythonOp(THPObjectPtr(func.release().ptr()), cconv, {}));

  new_node->setSourceRange(loc);
  for (auto& i : matched_schema.inputs)
    new_node->addInput(i);

  Value* output =
      new_node->addOutput()->setType(matched_schema.return_types.at(0));
  return std::make_shared<SimpleValue>(output);
}

std::string PythonValue::kind() const {
  std::stringstream ss;
  ss << "python value of type '" << typeString(self) << "'";
  return ss.str();
}

std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
    const SourceRange& loc,
    GraphFunction& m,
    const std::optional<size_t>& size_hint) {
  std::stringstream ss;
  ss << kind() << " cannot be used as a tuple";
  checkForAddToConstantsError(ss);
  throw(ErrorReport(loc) << ss.str());
}

std::shared_ptr<SugaredValue> PythonValue::attr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& field) {
  std::stringstream ss;
  ss << "attribute lookup is not defined on " << kind();
  checkForAddToConstantsError(ss);
  throw(ErrorReport(loc) << ss.str());
}

py::object PythonValue::getattr(
    const SourceRange& loc,
    const std::string& name) {
  try {
    return py::getattr(self, name.c_str());
  } catch (py::error_already_set& e) {
    throw(ErrorReport(loc) << "object has no attribute " << name);
  }
}

void PythonValue::checkForAddToConstantsError(std::stringstream& ss) {
  auto nn = py::module::import("torch.nn");
  if (py::isinstance(self, nn.attr("ModuleList")) ||
      py::isinstance(self, nn.attr("Sequential"))) {
    ss << ". Did you forget to add it to __constants__? ";
  }
}

std::shared_ptr<SugaredValue> PythonModuleValue::attr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& field) {
  py::object member = getattr(loc, field);
  // note: is_constant = true because we consider global properties on
  // modules, like math.pi or torch.float, to be constants even though it
  // is possible, though rare, for someone to mutate them
  return toSugaredValue(member, m, loc, /*is_constant=*/true);
}

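// Attribute lookup on the torch.cuda module: supported CUDA operators are
// resolved to JIT builtins, Stream/Event to their custom classes, and
// everything else falls back to the generic Python module handling.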
std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& field) {
  // List of all the cuda operators which are supported in JIT
  const std::unordered_set<std::string> cuda_ops = {
      "current_stream",
      "default_stream",
      "current_device",
      "_exchange_device",
      "_maybe_exchange_device",
      "set_device",
      "device_index",
      "device_count",
      "set_stream",
      "synchronize"};

  if (cuda_ops.find(field) != cuda_ops.end()) {
    // Both the current_device and set_device APIs are part of the c10::cuda
    // namespace. To resolve the name conflict for JIT, we prefix these APIs
    // with an underscore.
    if (field == "current_device" || field == "set_device") {
      return std::make_shared<BuiltinFunction>(
          Symbol::cuda("_" + field), std::nullopt);
    } else {
      return std::make_shared<BuiltinFunction>(
          Symbol::cuda(field), std::nullopt);
    }
  }

  if (field == "Stream" || field == "Event") {
    auto class_type = getCustomClass("__torch__.torch.classes.cuda." + field);
    return std::make_shared<ClassValue>(class_type);
  }

  py::object member = getattr(loc, field);
  // note: is_constant = true because we consider global properties on
  // modules, like math.pi or torch.float, to be constants even though it
  // is possible, though rare, for someone to mutate them
  return toSugaredValue(member, m, loc, /*is_constant=*/true);
}

Value* ModuleValue::asValue(const SourceRange& loc, GraphFunction& m) {
  return self_;
}

SugaredValuePtr ModuleValue::asTupleValue(
    const SourceRange& loc,
    GraphFunction& m) {
  if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
    auto dict = getSugaredDict(loc, m);
    auto mods = dict->getModules();
    return mods;
  }
  throw(
      ErrorReport(loc)
      << "Only ModuleList or Sequential modules can be used as tuple");
}

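// Returns true if every submodule attribute of this module is a subtype of
// `ty`; on failure, writes an explanation to `why_not` when it is non-null.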
bool ModuleValue::areAllSubmodulesSubtypeOf(
    const TypePtr& ty,
    std::ostream* why_not) const {
  const auto& self_type = concreteType_->getJitType()->expect<ClassType>();
  for (size_t i = 0; i < self_type->numAttributes(); ++i) {
    const auto& attr_type = self_type->getAttribute(i);
    if (attr_type->is_module()) {
      std::stringstream ss;
      if (!attr_type->isSubtypeOfExt(ty, &ss)) {
        if (why_not) {
          *why_not << "Attribute " << self_type->getAttributeName(i)
                   << " is not of annotated type " << ty->annotation_str()
                   << ": " << ss.str();
        }

        return false;
      }
    }
  }

  return true;
}

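// Desugar `module[idx]`. Indexing is supported for ModuleList/Sequential,
// ParameterList, ModuleDict, and ParameterDict; ModuleDict/ParameterDict
// require a string-literal key unless a type hint is provided, in which case
// a prim::ModuleContainerIndex op is emitted instead.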
SugaredValuePtr ModuleValue::getitem(
    const SourceRange& loc,
    GraphFunction& m,
    Value* idx,
    TypePtr type_hint) {
  if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
    if (type_hint) {
      // Check that all submodules comply with the type hint.
      std::stringstream ss;
      if (!areAllSubmodulesSubtypeOf(type_hint, &ss)) {
        throw(ErrorReport(loc) << ss.str());
      }

      // Emit a prim::ModuleContainerIndex operator. This is needed because
      // it's difficult to construct a list in the graph representing the
      // ModuleList and use aten::__getitem__ ops to index into it because
      // any call to ModuleList.setitem would invalidate that emitted list.
      auto graph = m.graph();
      auto* getitem_node = graph->insertNode(
          graph->create(prim::ModuleContainerIndex, {self_, idx}));
      getitem_node->output(0)->setType(type_hint);
      return std::make_shared<SimpleValue>(getitem_node->output(0));
    } else {
      return getSugaredDict(loc, m)->getModules()->getitem(
          loc, m, idx, type_hint);
    }
  } else if (
      concreteType_->getIterableModuleKind() == IterableModuleKind::PARAMLIST) {
    return getSugaredNamedParameterList(loc, m)->getModules()->getitem(
        loc, m, idx, type_hint);
  } else if (
      concreteType_->getIterableModuleKind() == IterableModuleKind::DICT ||
      concreteType_->getIterableModuleKind() == IterableModuleKind::PARAMDICT) {
    if (auto ivalue = toIValue(idx)) {
      std::shared_ptr<SugaredDict> sd;
      if (concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
        sd = getSugaredDict(loc, m);
      } else if (
          concreteType_->getIterableModuleKind() ==
          IterableModuleKind::PARAMDICT) {
        sd = getSugaredNamedParameterDict(loc, m);
      }
      auto idx_str = ivalue->toStringRef();
      auto keys_iter = sd->keys_;
      auto module_values_iter = sd->modules_;
      for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
        auto key = keys_iter->tup_.at(i);
        auto key_str = toIValue(key->asValue(loc, m))->toStringRef();
        if (key_str == idx_str) {
          return module_values_iter->tup_.at(i);
        }
      }
      throw(ErrorReport(loc) << "Key Error, " << idx_str);
    } else if (type_hint) {
      // Check that all submodules comply with the type hint.
      std::stringstream ss;
      if (!areAllSubmodulesSubtypeOf(type_hint, &ss)) {
        throw(ErrorReport(loc) << ss.str());
      }

      // Emit a prim::ModuleContainerIndex operator. This is needed because
      // it's difficult to construct a dict in the graph representing the
      // ModuleDict and use aten::__getitem__ ops to index into it because
      // any call to ModuleDict.setAttr would invalidate that emitted dict.
      auto graph = m.graph();
      auto* getitem_node = graph->insertNode(
          graph->create(prim::ModuleContainerIndex, {self_, idx}));
      getitem_node->output(0)->setType(type_hint);
      return std::make_shared<SimpleValue>(getitem_node->output(0));
    }
    throw(
        ErrorReport(loc)
        << "Unable to extract string literal index. "
        << "ModuleDict indexing is only supported with string literals. "
        << "For example, 'i = \"a\"; self.layers[i](x)' will fail because i is not a literal. "
        << "Enumeration of ModuleDict is supported, e.g. 'for k, v in self.items(): out = v(inp)'");
  }
  throw(
      ErrorReport(loc)
      << "Only ModuleList, Sequential, ModuleDict, "
      << "ParameterList, and ParameterDict modules are subscriptable");
}

void checkInterface(
    const SourceRange& loc,
    GraphFunction& m,
    const std::shared_ptr<ModuleValue>& self,
    const std::string& field) {
  if (self->asValue(loc, m)->type()->cast<InterfaceType>()) {
    throw(
        ErrorReport(loc)
        << "Could not compile " << field
        << "() because module is an interface type. Please file issue.");
  }
}

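// Depth-first walk over nested submodules, collecting a dotted prefix key and
// the corresponding ModuleValue for each one. Used to implement modules() and
// named_modules().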
void recurseThroughNestedModules(
    const SourceRange& loc,
    GraphFunction& m,
    std::vector<SugaredValuePtr>& keys,
    std::vector<SugaredValuePtr>& values,
    std::shared_ptr<ModuleValue>& self,
    const std::string& prefix,
    const std::string& field) {
  auto prefix_value =
      std::make_shared<SimpleValue>(insertConstant(*m.graph(), prefix));

  keys.push_back(prefix_value);
  values.push_back(self);

  checkInterface(loc, m, self, field);
  auto module_dict = self->getSugaredDict(loc, m);
  auto keys_iter = module_dict->keys_;
  auto module_values_iter = module_dict->modules_;
  for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
    std::shared_ptr<SugaredValue> module_sugared_value =
        module_values_iter->tup_.at(i);
    auto module_value =
        std::dynamic_pointer_cast<ModuleValue>(module_sugared_value);

    auto keys_value = keys_iter->tup_.at(i);
    auto key_string = toIValue(keys_value->asValue(loc, m))->toStringRef();
    std::string submodule_prefix = prefix;
    if (!prefix.empty()) {
      submodule_prefix = prefix + ".";
    }
    submodule_prefix += key_string;
    recurseThroughNestedModules(
        loc, m, keys, values, module_value, submodule_prefix, field);
  };
}

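// Build a SugaredDict mapping buffer names to their values, backing
// named_buffers().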
std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
    const SourceRange& loc,
    GraphFunction& m) {
  std::vector<std::string> paramNames;
  std::vector<SugaredValuePtr> values;

  const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
  for (size_t i = 0; i < selfType->numAttributes(); ++i) {
    if (selfType->is_buffer(i)) {
      paramNames.push_back(selfType->getAttributeName(i));
    }
  }

  std::vector<SugaredValuePtr> keys;
  for (const auto& name : paramNames) {
    auto name_v =
        std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
    m.graph()->insertGetAttr(self_, name);
    values.push_back(tryGetAttr(loc, m, name));
    keys.push_back(name_v);
  }

  return std::make_shared<SugaredDict>(
      std::make_shared<ModuleValue>(self_, concreteType_),
      std::make_shared<SugaredTupleValue>(keys),
      std::make_shared<SugaredTupleValue>(values));
}

std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedParameterList(
    const SourceRange& loc,
    GraphFunction& m) {
  std::vector<std::string> paramNames;
  std::vector<SugaredValuePtr> values;

  const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
  for (size_t i = 0; i < selfType->numAttributes(); ++i) {
    if (selfType->is_parameter(i)) {
      paramNames.push_back(selfType->getAttributeName(i));
    }
  }

  std::vector<SugaredValuePtr> keys;
  for (const auto& name : paramNames) {
    auto name_v =
        std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
    m.graph()->insertGetAttr(self_, name);
    values.push_back(tryGetAttr(loc, m, name));
    keys.push_back(name_v);
  }

  return std::make_shared<SugaredDict>(
      std::make_shared<ModuleValue>(self_, concreteType_),
      std::make_shared<SugaredTupleValue>(keys),
      std::make_shared<SugaredTupleValue>(values));
}

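// Build a SugaredDict of this module's direct submodules, keyed by attribute
// name. This backs iteration over ModuleList/ModuleDict as well as
// children()/named_children().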
std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
    const SourceRange& loc,
    GraphFunction& m) {
  std::vector<std::string> submoduleNames;
  const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
  for (size_t i = 0; i < selfType->numAttributes(); ++i) {
    const auto& attrType = selfType->getAttribute(i);
    if (attrType->is_module()) {
      submoduleNames.push_back(selfType->getAttributeName(i));
    }
  }

  std::vector<SugaredValuePtr> keys;
  std::vector<SugaredValuePtr> values;
  for (const auto& name : submoduleNames) {
    auto name_v =
        std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
    Value* module_v = m.graph()->insertGetAttr(self_, name);
    auto mod_v = std::make_shared<ModuleValue>(
        module_v, concreteType_->findSubmoduleConcreteType(name));

    keys.push_back(name_v);
    values.push_back(mod_v);
  }

  return std::make_shared<SugaredDict>(
      std::make_shared<ModuleValue>(self_, concreteType_),
      std::make_shared<SugaredTupleValue>(keys),
      std::make_shared<SugaredTupleValue>(values));
}

std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedParameterDict(
    const SourceRange& loc,
    GraphFunction& m) {
  std::vector<std::string> paramNames;
  const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
  for (size_t i = 0; i < selfType->numAttributes(); ++i) {
    if (selfType->is_parameter(i)) {
      paramNames.push_back(selfType->getAttributeName(i));
    }
  }

  std::vector<SugaredValuePtr> keys;
  std::vector<SugaredValuePtr> values;
  for (const auto& name : paramNames) {
    auto name_v =
        std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
    m.graph()->insertGetAttr(self_, name);
    auto val = tryGetAttr(loc, m, name);
    TORCH_INTERNAL_ASSERT(val != nullptr, "Could not find attribute ", name);
    values.push_back(val);
    keys.push_back(name_v);
  }

  return std::make_shared<SugaredDict>(
      std::make_shared<ModuleValue>(self_, concreteType_),
      std::make_shared<SugaredTupleValue>(keys),
      std::make_shared<SugaredTupleValue>(values));
}

std::shared_ptr<SugaredValue> SugaredDict::attr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& field) {
  // Recursive compilation does not maintain module aliasing,
  // so we do not add uniqueness checks on
  // "children"/"named_children"/"modules"/"named_modules"
  checkInterface(loc, m, self_, field);
  if (field == "keys") {
    return std::make_shared<ModuleDictMethod>(keys_, "keys");
  } else if (field == "values" || field == "children") {
    return std::make_shared<ModuleDictMethod>(modules_, field);
  } else if (
      field == "items" || field == "named_children" ||
      field == "named_buffers") {
    auto iterator = std::make_shared<IterableTree>();
    iterator->addChild(loc, m, keys_);
    iterator->addChild(loc, m, modules_);
    return std::make_shared<ModuleDictMethod>(iterator, field);
  } else if (field == "named_modules" || field == "modules") {
    std::vector<SugaredValuePtr> keys;
    std::vector<SugaredValuePtr> values;
    recurseThroughNestedModules(loc, m, keys, values, self_, "", field);
    if (field == "modules") {
      return std::make_shared<ModuleDictMethod>(
          std::make_shared<SugaredTupleValue>(values), field);
    } else {
      auto iterator = std::make_shared<IterableTree>();
      iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(keys));
      iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(values));
      return std::make_shared<ModuleDictMethod>(iterator, field);
    }
  }
  TORCH_INTERNAL_ASSERT(false);
}

std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
    const py::object& obj,
    GraphFunction& m,
    const SourceRange& loc) {
  auto annotation_type = py::module::import("torch.jit.annotations")
                             .attr("try_ann_to_type")(obj, loc);
  TORCH_INTERNAL_ASSERT(!annotation_type.is_none());
  auto type = py::cast<TypePtr>(annotation_type);
  auto enum_type = type->expect<EnumType>();
  return std::make_shared<SugaredEnumClass>(enum_type);
}

// helper function for instantiating a SugaredValue from an IValue
std::shared_ptr<SugaredValue> toSugaredValue(
    const IValue& v,
    GraphFunction& m,
    const SourceRange& loc) {
  if (v.isTuple()) {
    auto tp = v.toTuple();
    std::vector<Value*> values;
    values.reserve(tp->elements().size());
    for (const auto& e : tp->elements()) {
      values.push_back(toSugaredValue(e, m, loc)->asValue(loc, m));
    }
    return toSimple(
        m.graph()->insertNode(m.graph()->createTuple(values))->output());
  } else {
    return toSimple(m.graph()->insertConstant(v, loc));
  }
}

// This method controls how we desugar attribute lookups on ScriptModules
std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& field) {
  // 1. Look inside Module object for the field.
  const auto& selfType_ = concreteType_->getJitType();
  if (selfType_->cast<InterfaceType>()) {
    return std::make_shared<SimpleValue>(self_)->attr(loc, m, field);
  }

  const auto& selfType = selfType_->expect<ClassType>();

  if (selfType->hasAttribute(field) &&
      selfType->getAttribute(field)->is_module()) {
    // ...if it's a submodule, return it as a new ModuleValue.
    if (const auto submoduleConcreteType =
            concreteType_->findSubmoduleConcreteType(field)) {
      return std::make_shared<ModuleValue>(
          m.graph()->insertGetAttr(self_, field), submoduleConcreteType);
    }

    return std::make_shared<ModuleValue>(
        m.graph()->insertGetAttr(self_, field),
        ConcreteModuleType::fromJitType(selfType->getAttribute(field)));
  } else if (selfType->hasAttribute(field) || selfType->findMethod(field)) {
    // ...otherwise, methods, parameters, attributes, and buffers are all
    // first class so they get returned as SimpleValues
    return std::make_shared<SimpleValue>(self_)->attr(loc, m, field);
  } else if (selfType->hasConstant(field)) {
    auto v = selfType->getConstant(field);
    return toSugaredValue(v, m, loc);
  }

  // 2. Special case: for module dicts we manually desugar items(), keys(),
  // values() calls into the appropriate method.
  if (concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
    if (field == "items" || field == "keys" || field == "values") {
      return getSugaredDict(loc, m)->attr(loc, m, field);
    }
  }

  if (field == "named_modules" || field == "modules" || field == "children" ||
      field == "named_children") {
    return getSugaredDict(loc, m)->attr(loc, m, field);
  }

  if (field == "named_buffers") {
    return getSugaredNamedBufferDict(loc, m)->attr(loc, m, field);
  }

  // 3. Check if this is the name of an overloaded method.

  // This can also be a call to a non-script module, or a plain
  // python method. If so return this as a python value.
  if (const auto overloads = concreteType_->findOverloads(field)) {
    return std::make_shared<MethodValue>(self_, *overloads);
  }

  // 4. Check if it's a function attribute.
  if (const auto fnAttr = concreteType_->findFunctionAttribute(field)) {
    return std::make_shared<FunctionValue>(*fnAttr);
  } else if (const auto builtin = concreteType_->findBuiltinFunction(field)) {
    return std::make_shared<BuiltinFunction>(*builtin, /*self=*/std::nullopt);
  }

  // 5. Check if it's an attribute of the original Python class that this
  // ScriptModule was derived from. The only class attributes we handle are
  // methods.
  const auto maybePyClass = concreteType_->getPyClass();
  if (!maybePyClass) {
    // ConcreteType doesn't always have an originating Python class, e.g. if it
    // was derived from a serialized ScriptModule. In this case, we've exhausted
    // our options for attr lookup.
    return nullptr;
  }
  py::object unboundMethod = py::getattr(
      *maybePyClass, field.c_str(), pybind11::cast<pybind11::none>(Py_None));

  if (py::isinstance<py::function>(unboundMethod)) {
    bool isStaticFn =
        py::cast<bool>(py::module::import("torch._jit_internal")
                           .attr("is_static_fn")(*maybePyClass, field.c_str()));
    if (isStaticFn) {
      // Functions within the module annotated with @staticmethod do not need
      // binding.
      py::object staticFn =
          py::module::import("torch._jit_internal")
              .attr("get_static_fn")(*maybePyClass, field.c_str());
      return toSugaredValue(staticFn, m, loc);
    }
    // For Python methods that we're trying to call directly, we need to bind
    // the method to a self. (see the documentation for lazy_bind in Python for
    // more info).
    bool isIgnoredFn =
        py::cast<bool>(py::module::import("torch._jit_internal")
                           .attr("is_ignored_fn")(unboundMethod));
    if (isIgnoredFn) {
      // Create a generated ScriptModule type with module_ set as cpp_module
      auto boundMethod = py::module::import("torch.jit._recursive")
                             .attr("lazy_bind")(concreteType_, unboundMethod);
      TORCH_CHECK(py::isinstance<py::function>(boundMethod));
      auto rcb =
          py::module::import("torch._jit_internal")
              .attr("createResolutionCallbackFromClosure")(unboundMethod);
      return std::make_shared<PythonValue>(boundMethod, rcb, self_);
    }

    // If we reach here, it's because this is a "normal" method that just hasn't
    // been compiled yet (directly exported methods would have been returned by
    // step 1). Just compile it.
    auto stub =
        py::module::import("torch.jit._recursive")
            .attr("compile_unbound_method")(concreteType_, unboundMethod);
    TORCH_INTERNAL_ASSERT(!stub.is_none());
    // Look up the attribute again, it will be available as a compiled method.
    return attr(loc, m, field);
  }

  return nullptr;
}

bool ModuleValue::hasAttr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& field) {
  return tryGetAttr(loc, m, field) != nullptr;
}

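// Desugar calling a submodule: run any forward pre-hooks (which may rewrite
// the inputs), then forward, then any forward hooks (which may rewrite the
// output).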
std::shared_ptr<SugaredValue> ModuleValue::call(
    const SourceRange& loc,
    GraphFunction& caller,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    size_t n_binders) {
  c10::ClassTypePtr class_type = concreteType_->getJitType()->cast<ClassType>();
  bool have_pre_hooks = class_type && !class_type->getForwardPreHooks().empty();
  bool have_hooks = class_type && !class_type->getForwardHooks().empty();

  std::vector<Value*> arg_values;
  std::vector<NamedValue> pre_hook_result;
  Value* forward_input = nullptr;
  std::shared_ptr<Graph> calling_graph = caller.graph();

  if (have_pre_hooks || have_hooks) {
    // convert forward args into a tuple for forward hooks
    // (the inputs of eager hooks are always tuples)
    for (const auto& sv : args) {
      arg_values.push_back(sv.value(*calling_graph));
    }
    forward_input =
        calling_graph->insertNode(calling_graph->createTuple(arg_values))
            ->output();
  }

  // call pre_hooks
  if (have_pre_hooks) {
    for (const auto& hook : class_type->getForwardPreHooks()) {
      TORCH_INTERNAL_ASSERT(forward_input != nullptr);
      Value* pre_hook_output =
          FunctionValue(hook)
              .call(
                  loc,
                  caller,
                  {NamedValue(self_), NamedValue(forward_input)},
                  kwargs,
                  n_binders)
              ->asValue(loc, caller);
      if (pre_hook_output->type() != NoneType::get()) {
        if (pre_hook_output->type()->kind() != TypeKind::TupleType) {
          pre_hook_output =
              calling_graph
                  ->insertNode(calling_graph->createTuple({pre_hook_output}))
                  ->output();
        }
        forward_input = pre_hook_output;
      }
    }
    // de-tuple pre_hook output for forward
    at::ArrayRef<Value*> output_nodes =
        calling_graph
            ->insertNode(calling_graph->createTupleUnpack(forward_input))
            ->outputs();
    for (auto& output_node : output_nodes) {
      pre_hook_result.emplace_back(output_node);
    }
    if (!args.empty()) { // only replace input if it existed
      args = pre_hook_result;
    }
  }

  // call forward
  std::shared_ptr<SugaredValue> forwardSV =
      attr(loc, caller, "forward")->call(loc, caller, args, kwargs, n_binders);
  Value* forward_output = forwardSV->asValue(loc, caller);

  // call hooks
  if (have_hooks) {
    for (const auto& hook : class_type->getForwardHooks()) {
      Value* forward_hook_output = FunctionValue(hook)
                                       .call(
                                           loc,
                                           caller,
                                           {NamedValue(self_),
                                            NamedValue(forward_input),
                                            NamedValue(forward_output)},
                                           kwargs,
                                           n_binders)
                                       ->asValue(loc, caller);
      if (forward_hook_output->type() != NoneType::get()) {
        forward_output = forward_hook_output;
      }
    }
  }

  return std::make_shared<SimpleValue>(forward_output);
}

// This method controls how we desugar attribute lookups on ScriptModules.
std::shared_ptr<SugaredValue> ModuleValue::attr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& field) {
  if (auto attr = tryGetAttr(loc, m, field)) {
    return attr;
  }

  // Check if it's a property.
  auto prop =
      concreteType_->getJitType()->expectRef<ClassType>().getProperty(field);
  if (prop) {
    return MethodValue(self_, prop->getter->name())
        .call(loc, m, {}, {}, /*n_binders=*/1);
  }

  // We don't define this attr. Bail out with a hint to the user.
  std::string hint;
  if (auto failureReason = concreteType_->findFailedAttribute(field)) {
    hint = *failureReason;
  } else if (concreteType_->isIgnoredAttribute(field)) {
    hint = "attribute was ignored during compilation";
  }

  throw(
      ErrorReport(loc)
      << "Module '"
      << concreteType_->getJitType()->expectRef<ClassType>().name()->name()
      << "'"
      << " has no attribute '" << field << "' " << hint);
}

SugaredValuePtr ModuleValue::iter(const SourceRange& loc, GraphFunction& m) {
  const auto iterableModuleKind = concreteType_->getIterableModuleKind();
  if (iterableModuleKind == IterableModuleKind::NONE) {
    throw(
        ErrorReport(loc)
        << "Only constant Sequential, ModuleList, ModuleDict, or "
        << "ParameterList can be used as an iterable");
  }

  if (iterableModuleKind == IterableModuleKind::DICT) {
    auto module_dict = getSugaredDict(loc, m);
    return module_dict->keys_;
  } else if (iterableModuleKind == IterableModuleKind::LIST) {
    auto module_dict = getSugaredDict(loc, m);
    return module_dict->modules_;
  } else if (iterableModuleKind == IterableModuleKind::PARAMLIST) {
    auto module_dict = getSugaredNamedParameterList(loc, m);
    return module_dict->modules_;
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
}

std::shared_ptr<SugaredValue> PythonClassValue::attr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& field) {
  // Resolve values from the Python object first (e.g. for static methods on
  // this type, resolve them as functions)
  if (auto* fn = type_->findStaticMethod(field)) {
    return std::make_shared<FunctionValue>(fn);
  }
  auto py_attr = py::getattr(py_type_, field.c_str(), py::none());
  if (!py_attr.is_none()) {
    return toSugaredValue(py_attr, m, loc);
  }

  return ClassValue::attr(loc, m, field);
}

bool PythonClassValue::hasAttr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& field) {
  try {
    py::getattr(py_type_, field.c_str());
    return true;
  } catch (py::error_already_set& e) {
    return false;
  }
}

void ModuleValue::setAttr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& field,
    Value* newValue) {
  // Forward to SimpleValue::setAttr
  SimpleValue simple(self_);
  simple.setAttr(loc, m, field, newValue);
}

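// Desugar a call to a boolean-dispatched function (see
// torch._jit_internal.boolean_dispatch): select the "if_true" or "if_false"
// implementation based on the dispatch flag, which must be a compile-time
// constant (or fall back to its default value).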
std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
    const SourceRange& loc,
    GraphFunction& caller,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    size_t n_binders) {
  std::optional<bool> result;
  Graph& graph = *(caller.graph());

  auto index = py::cast<size_t>(dispatched_fn_["index"]);
  auto arg_name = py::str(dispatched_fn_["arg_name"]);

  ErrorReport error(loc);
  if (index < args.size()) {
    // Dispatch flag is in arg list
    result = constant_as<bool>(args.at(index).value(graph));
    error << "Argument for boolean dispatch at position " << index
          << " was not constant";
  } else if (auto i = findInputWithName(arg_name, kwargs)) {
    // Dispatch flag is in kwargs
    result = constant_as<bool>(kwargs[*i].value(graph));
    error << "Keyword argument '" << arg_name
          << "' for boolean dispatch at position was not constant";
  } else {
    // Didn't find dispatch flag, so use default value
    result = py::cast<bool>(dispatched_fn_["default"]);
    TORCH_INTERNAL_ASSERT(result);
  }

  if (!result.has_value()) {
    throw ErrorReport(error);
  }

  std::shared_ptr<SugaredValue> value;
  if (*result) {
    value = toSugaredValue(dispatched_fn_["if_true"], caller, loc);
  } else {
    value = toSugaredValue(dispatched_fn_["if_false"], caller, loc);
  }
  return value->call(loc, caller, args, kwargs, n_binders);
}

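// Desugar construction of a Python exception: pack the constructor arguments
// into an error message value (a tuple if there are several) together with
// the exception's qualified class name.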
std::shared_ptr<SugaredValue> PythonExceptionValue::call(
    const SourceRange& loc,
    GraphFunction& caller,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    size_t /*n_binders*/) {
  Value* error_message = nullptr;
  if (args.empty()) {
    error_message = insertConstant(*caller.graph(), "", loc);
  } else if (args.size() == 1) {
    error_message = args.at(0).value(*caller.graph());
  } else {
    std::vector<Value*> message_values;
    message_values.reserve(args.size() + kwargs.size());

    for (const auto& inp : args) {
      message_values.push_back(inp.value(*caller.graph()));
    }
    for (const auto& kwarg_inp : kwargs) {
      message_values.push_back(kwarg_inp.value(*caller.graph()));
    }
    error_message =
        caller.graph()
            ->insertNode(caller.graph()->createTuple(message_values))
            ->output();
  }
  Value* qualified_class_name =
      insertConstant(*caller.graph(), exception_class_qualified_name_, loc);

  return std::make_shared<ExceptionMessageValue>(
      error_message, qualified_class_name);
}

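// A class counts as a NamedTuple if it is a subclass of tuple and defines
// `_fields`.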
bool isNamedTupleClass(const py::object& obj) {
  auto tuple_type = reinterpret_cast<PyObject*>(&PyTuple_Type);
  int is_tuple_class = PyObject_IsSubclass(obj.ptr(), tuple_type);
  if (is_tuple_class == -1) {
    PyErr_Clear();
    return false;
  }
  return is_tuple_class == 1 && py::hasattr(obj, "_fields");
}

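// Resolve a NamedTuple class to a JIT TupleType: query field names, types,
// and defaults via torch._jit_internal, then register the named TupleType in
// the Python compilation unit (or reuse a compatible existing registration).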
TypePtr registerNamedTuple(
    const py::object& obj,
    const SourceRange& loc,
    const ResolutionCallback& rcb) {
  TORCH_INTERNAL_ASSERT(isNamedTupleClass(obj));
  auto qualifiedName = c10::QualifiedName(py::cast<std::string>(
      py::module::import("torch._jit_internal").attr("_qualified_name")(obj)));

  // Note: we need to pass rcb to resolve ForwardRef annotations. See
  // [Note: ForwardRef annotations in NamedTuple attributes]
  py::object props =
      py::module::import("torch._jit_internal")
          .attr("_get_named_tuple_properties")(obj, loc, py::cpp_function(rcb));

  auto [unqualName, field_names, field_types, objects] = py::cast<std::tuple<
      std::string,
      std::vector<std::string>,
      std::vector<TypePtr>,
      std::vector<py::object>>>(props);

  std::vector<IValue> field_defaults;
  auto min_default_idx = field_names.size() - objects.size();
  for (size_t i = min_default_idx, j = 0; i < field_names.size(); ++i, ++j) {
    py::object o = objects[j];
    auto type = tryToInferType(objects[j]);
    IValue ival = toIValue(objects[j], type.type());
    TORCH_CHECK(
        ival.tagKind() != "Tensor",
        "Tensors are"
        " not supported as default NamedTuple fields. Their "
        "mutability could lead to potential memory aliasing "
        "problems");
    field_defaults.emplace_back(ival);
  }

  auto tt = TupleType::createNamed(
      qualifiedName, field_names, field_types, field_defaults);
  if (auto type = get_python_cu()->get_type(qualifiedName)) {
    TORCH_CHECK(
        type->isSubtypeOf(tt), "Can't redefine NamedTuple: ", tt->repr_str());
    return type;
  }
  get_python_cu()->register_type(tt);
  return tt;
}

bool isEnumClass(py::object obj) {
  auto enum_type_obj =
      py::cast<py::object>(py::module::import("enum").attr("Enum"));
  int ret = PyObject_IsSubclass(obj.ptr(), enum_type_obj.ptr());
  if (ret == -1) {
    PyErr_Clear();
    return false;
  }
  return ret == 1;
}

std::shared_ptr<SugaredValue> createSimpleEnumValue(
    const py::object& obj,
    GraphFunction& m,
    const SourceRange& loc) {
  auto enum_class = obj.attr("__class__");
  auto enum_type =
      py::cast<TypePtr>(py::module::import("torch.jit.annotations")
                            .attr("try_ann_to_type")(enum_class, loc));
  auto enum_ivalue = toIValue(obj, enum_type);
  return toSimple(m.graph()->insertConstant(enum_ivalue, loc));
}

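// Desugar a call to the builtin `slice`, producing a SliceValue with
// defaulted start/stop/step for any unspecified arguments.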
std::shared_ptr<SugaredValue> PythonSliceClass::call(
    const SourceRange& loc,
    GraphFunction& caller,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    size_t /*n_binders*/) {
  if (!kwargs.empty()) {
    throw(ErrorReport(loc) << "Slice does not accept any keyword arguments");
  }

  static constexpr int64_t default_start = 0;
  static constexpr int64_t default_stop = std::numeric_limits<int64_t>::max();
  static constexpr int64_t default_step = 1;
  Graph& graph = *(caller.graph());

  auto ValOr = [&](Value* given, int64_t default_val) {
    if (!given || given->type()->isSubtypeOf(*NoneType::get())) {
      return graph.insertConstant(default_val, loc);
    }
    return given;
  };

  Value* start = nullptr;
  Value* stop = nullptr;
  Value* step = nullptr;
  size_t n = args.size();
  // Mirror Python's slice(): it accepts slice(stop), slice(start, stop), or
  // slice(start, stop, step).
  if (n == 1) {
    // Case where only `stop` is specified.
    start = ValOr(nullptr, default_start);
    stop = ValOr(args[0].value(graph), default_stop);
    step = ValOr(nullptr, default_step);
  } else if (n == 2) {
    // Case where `start` and `stop` are specified.
    start = ValOr(args[0].value(graph), default_start);
    stop = ValOr(args[1].value(graph), default_stop);
    step = ValOr(nullptr, default_step);
  } else if (n == 3) {
    // Case where `start`, `stop` and `step` are all specified.
    start = ValOr(args[0].value(graph), default_start);
    stop = ValOr(args[1].value(graph), default_stop);
    step = ValOr(args[2].value(graph), default_step);
  } else {
    throw(
        ErrorReport(loc) << "slice accepts exactly 1, 2 or 3 arguments, got: "
                         << n);
  }

  return std::make_shared<SliceValue>(start, stop, step);
}

std::shared_ptr<SugaredValue> toSugaredValue(
    py::object obj,
    GraphFunction& m,
    const SourceRange& loc,
    bool is_constant) {
  // directly create SimpleValues when possible, because they are first-class
  // and can be re-assigned. Otherwise, this would be invalid:
  // f = python_constant
  // while ...
  //   f = f + 1
  auto& g = *m.graph();
  if (is_constant) {
    if (py::isinstance<py::bool_>(obj)) {
      return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
    } else if (py::isinstance<py::int_>(obj)) {
      return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
    } else if (py::isinstance<py::float_>(obj)) {
      return toSimple(g.insertConstant(py::cast<double>(obj), loc));
    } else if (PyComplex_CheckExact(obj.ptr())) {
      auto c_obj = py::cast<std::complex<double>>(obj.ptr());
      return toSimple(
          g.insertConstant(static_cast<c10::complex<double>>(c_obj), loc));
    } else if (py::isinstance<py::str>(obj)) {
      return toSimple(g.insertConstant(py::cast<std::string>(obj), loc));
    } else if (obj.is_none()) {
      return toSimple(g.insertConstant(IValue(), loc));
    } else if (THPDevice_Check(obj.ptr())) {
      auto device = reinterpret_cast<THPDevice*>(obj.ptr());
      return toSimple(g.insertConstant(device->device));
    } else if (THPLayout_Check(obj.ptr())) {
      auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
      const auto v = static_cast<int64_t>(layout->layout);
      return toSimple(g.insertConstant(v, loc));
    } else if (THPMemoryFormat_Check(obj.ptr())) {
      auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
      const auto v = static_cast<int64_t>(memory_format->memory_format);
      return toSimple(g.insertConstant(v, loc));
    } else if (THPDtype_Check(obj.ptr())) {
      auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
      const auto v = static_cast<int64_t>(dtype->scalar_type);
      return toSimple(g.insertConstant(v, loc));
    } else if (THPQScheme_Check(obj.ptr())) {
      auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
      const auto v = static_cast<uint8_t>(qscheme->qscheme);
      return toSimple(g.insertConstant(v, loc));
    } else if (py::isinstance<py::tuple>(obj)) {
      py::tuple tup = obj;
      std::vector<Value*> values;
      values.reserve(tup.size());
      for (py::handle t : tup) {
        py::object obj = py::reinterpret_borrow<py::object>(t);
        values.push_back(toSugaredValue(obj, m, loc, true)->asValue(loc, m));
      }
      return toSimple(
          m.graph()->insertNode(m.graph()->createTuple(values))->output());
    }
  }

  auto opoverloadpacket_type =
      py::module::import("torch").attr("_ops").attr("OpOverloadPacket");
  py::bool_ is_overloadpacket = py::isinstance(obj, opoverloadpacket_type);
  if (is_overloadpacket) {
    obj = py::getattr(obj, "op");
  }

#ifdef USE_RPC
  bool isRpcAvailable = py::cast<bool>(
      py::module::import("torch.distributed.rpc").attr("is_available")());
#endif

  if (auto callee = as_function(obj)) {
    return std::make_shared<FunctionValue>(callee->function_);
  } else if (py::isinstance<py::module>(obj)) {
    std::string obj_name = py::cast<py::str>(py::getattr(obj, "__name__"));
    if (obj_name == "torch.cuda") {
      return std::make_shared<CUDAPythonModuleValue>(obj);
    }
    return std::make_shared<PythonModuleValue>(obj);
  } else if (
      obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr() ||
      obj.ptr() == py::module::import("torch.jit").attr("fork").ptr()) {
    return SpecialFormValue::create(prim::fork);
  } else if (
      obj.ptr() == py::module::import("torch.jit").attr("_awaitable").ptr()) {
    return SpecialFormValue::create(prim::awaitable);
  } else if (
      obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) {
    return SpecialFormValue::create(prim::annotate);
  } else if (
      obj.ptr() == py::module::import("torch.jit").attr("isinstance").ptr()) {
    return SpecialFormValue::create(prim::isinstance);
#ifdef USE_RPC
  // The RPC module is only available when the build flag "USE_DISTRIBUTED"
  // is on.
  } else if (
      isRpcAvailable &&
      obj.ptr() ==
          py::module::import("torch.distributed.rpc").attr("rpc_async").ptr()) {
    return SpecialFormValue::create(prim::rpc_async);
  } else if (
      isRpcAvailable &&
      obj.ptr() ==
          py::module::import("torch.distributed.rpc").attr("rpc_sync").ptr()) {
    return SpecialFormValue::create(prim::rpc_sync);
  } else if (
      isRpcAvailable &&
      // The RPC module is only available when the build flag
      // "USE_DISTRIBUTED" is on.
      obj.ptr() ==
          py::module::import("torch.distributed.rpc").attr("remote").ptr()) {
    return SpecialFormValue::create(prim::rpc_remote);
#endif
  } else if (auto callee = as_module(obj)) {
    throw(
        ErrorReport(loc) << "Cannot call a ScriptModule that is not"
                         << " a submodule of the caller");
  }
  std::vector<std::pair<const char*, at::ScalarType>> tensor_names = {
      {"BoolTensor", at::ScalarType::Bool},
      {"LongTensor", at::ScalarType::Long},
      {"ByteTensor", at::ScalarType::Byte},
      {"CharTensor", at::ScalarType::Char},
      {"DoubleTensor", at::ScalarType::Double},
      {"FloatTensor", at::ScalarType::Float},
      {"IntTensor", at::ScalarType::Int},
      {"ShortTensor", at::ScalarType::Short},
      {"HalfTensor", at::ScalarType::Half},
  };
  for (const auto& name : tensor_names) {
    if (obj.ptr() == py::module::import("torch").attr(name.first).ptr()) {
      // torch.LongTensor and the other legacy constructors create tensors on
      // CPU.
      // TODO: add support for torch.cuda.LongTensor for GPU
      return LegacyTensorConstructor::create(
          prim::LegacyTypedConstructor, name.second, at::kCPU);
    }
  }

  py::object builtin_name =
      py::module::import("torch.jit._builtins").attr("_find_builtin")(obj);
  if (!builtin_name.is_none()) {
    return std::make_shared<BuiltinFunction>(
        Symbol::fromQualString(py::str(builtin_name)), std::nullopt);
  }

  if (py::cast<bool>(py::module::import("torch._jit_internal")
                         .attr("_is_exception")(obj))) {
    return std::make_shared<PythonExceptionValue>(obj);
  }

  if (py::isinstance<py::function>(obj)) {
    if (typeString(obj) == "builtin_function_or_method") {
      throw(
          ErrorReport(loc) << "Python builtin " << py::str(obj)
1285 << " is currently not supported in Torchscript");
    }
  }

  py::object dispatched_fn = py::module::import("torch._jit_internal")
                                 .attr("_try_get_dispatched_fn")(obj);
  if (!dispatched_fn.is_none()) {
    return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
  }

  if (py::isinstance<ScriptClass>(obj)) {
    auto script_class = py::cast<ScriptClass>(obj);
    return std::make_shared<PythonClassValue>(
        script_class.class_type_.type_->expect<ClassType>(), obj);
  }

  if (isNamedTupleClass(obj)) {
    // The use of fakeRcb here prevents us from correctly resolving ForwardRef
    // annotations on NamedTuple attributes for instances whose types are
    // inferred. See #95858 for more details, as well as
    // [Note: ForwardRef annotations in NamedTuple attributes]
    auto fakeRcb =
        py::module::import("torch.jit.annotations").attr("_fake_rcb");
    auto tuple_type =
        registerNamedTuple(obj, loc, fakeRcb)->expect<TupleType>();
    return std::make_shared<NamedTupleConstructor>(tuple_type);
  }

  if (isEnumClass(obj)) {
    return createSugaredEnumClassFromObj(obj, m, loc);
  }

  auto enum_type = py::module::import("enum").attr("Enum");
  py::bool_ is_enum_value = py::isinstance(obj, enum_type);
  if (py::cast<bool>(is_enum_value)) {
    return createSimpleEnumValue(obj, m, loc);
  }

  py::bool_ is_class = py::module::import("inspect").attr("isclass")(obj);
  if (py::cast<bool>(is_class)) {
    py::str qualifiedName =
        py::module::import("torch._jit_internal").attr("_qualified_name")(obj);
    auto pyCu = get_python_cu();
    auto qualname = c10::QualifiedName(qualifiedName);

    if (auto classType = pyCu->get_class(qualname)) {
      return std::make_shared<PythonClassValue>(classType, obj);
    } else {
      // If we can't get the source code for the type, it's implemented in C and
      // probably part of the standard library, so give up and leave it as a
      // call to Python
      bool can_compile_class =
          py::cast<bool>(py::module::import("torch._jit_internal")
                             .attr("can_compile_class")(obj));
      if (can_compile_class) {
        // Register class
        auto rcb = py::module::import("torch._jit_internal")
                       .attr("createResolutionCallbackForClassMethods")(obj);
        py::module::import("torch.jit._script")
            .attr("_recursive_compile_class")(obj, loc);

        // Return class
        auto newClassType = pyCu->get_class(qualname);
        AT_ASSERT(
            newClassType,
            "Class '",
            qualifiedName,
            "' should have been compiled but was not");
        return std::make_shared<PythonClassValue>(newClassType, obj);
      }
    }
  }

  py::bool_ isFunction = py::module::import("inspect").attr("isfunction")(obj);
  if (py::cast<bool>(isFunction)) {
    auto overloads =
        py::module::import("torch.jit._script").attr("_get_overloads")(obj);
    if (!overloads.is_none()) {
      auto compiled_fns = py::cast<std::vector<StrongFunctionPtr>>(overloads);
      return std::make_shared<FunctionValue>(std::move(compiled_fns));
    }

    auto compiled_fn = py::module::import("torch.jit._recursive")
                           .attr("try_compile_fn")(obj, loc);
    if (auto callee = as_function(compiled_fn)) {
      return std::make_shared<FunctionValue>(*callee);
    }
  }
  if (obj.ptr() == py::module::import("math").attr("inf").ptr()) {
    return toSimple(
        g.insertConstant(std::numeric_limits<double>::infinity(), loc));
  }

  py::bool_ isMethod = py::module::import("inspect").attr("ismethod")(obj);
  // methods here have been explicitly annotated to not be compiled,
  // so they do not have the same overload and compile checks as for functions
  if (isFunction || isMethod) {
    auto rcb = py::module::import("torch._jit_internal")
                   .attr("createResolutionCallbackFromClosure")(obj);
    return std::make_shared<PythonValue>(obj, rcb);
  }

  if (obj.is(py::module::import("builtins").attr("slice"))) {
    return std::make_shared<PythonSliceClass>();
  }

  return std::make_shared<PythonValue>(obj);
}
} // namespace torch::jit