xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_sugared_value.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/python/python_sugared_value.h>
2 
3 #include <ATen/core/interned_strings.h>
4 #include <c10/core/ScalarType.h>
5 #include <pybind11/pytypes.h>
6 #include <torch/csrc/Dtype.h>
7 #include <torch/csrc/Layout.h>
8 #include <torch/csrc/MemoryFormat.h>
9 #include <torch/csrc/jit/frontend/schema_matching.h>
10 #include <torch/csrc/jit/python/module_python.h>
11 #include <torch/csrc/utils/pybind.h>
12 #include <climits>
13 #include <memory>
14 #include <sstream>
15 #include <string>
16 #include <tuple>
17 #include <vector>
18 
19 #include <Python.h>
20 
21 namespace torch::jit {
22 
typeString(py::handle h)23 std::string typeString(py::handle h) {
24   return py::str(h.get_type().attr("__name__"));
25 }
26 
as_function(const py::object & obj)27 std::optional<StrongFunctionPtr> as_function(const py::object& obj) {
28   if (py::isinstance<StrongFunctionPtr>(obj)) {
29     return py::cast<StrongFunctionPtr>(obj);
30   }
31   return std::nullopt;
32 }
33 
getSchema(const size_t n_args,const size_t n_binders,const SourceRange & loc)34 FunctionSchema PythonValue::getSchema(
35     const size_t n_args,
36     const size_t n_binders,
37     const SourceRange& loc) {
38   auto annotations = py::module::import("torch.jit.annotations");
39   const auto callable = moduleSelf_ ? py::getattr(self, "original_fn") : self;
40 
41   // Make sure the function is not a class instantiation (e.g. `Exception()`)
42   annotations.attr("check_fn")(callable, loc);
43   auto is_vararg = py::cast<bool>(annotations.attr("is_vararg")(callable));
44 
45   auto signature = annotations.attr("get_signature")(
46       callable, rcb ? *rcb : py::none(), loc, bool(moduleSelf_));
47   std::vector<Argument> args, rets;
48 
49   auto py_param_names = annotations.attr("get_param_names")(callable, n_args);
50   auto param_names = py::cast<std::vector<std::string>>(py_param_names);
51   auto names_it = param_names.begin();
52   if (moduleSelf_) {
53     if (param_names.empty()) {
54       throw(
55           ErrorReport(loc)
56           << "Non-static method does not have a self argument");
57     }
58 
59     // If there is a `self` parameter on the callable, skip it on the names list
60     args.emplace_back(Argument(*names_it, moduleSelf_->type(), {}, {}, false));
61     ++names_it;
62   }
63   if (signature.is_none()) {
64     // No type signature was provided on the callable, so make a default
65     // signature where each argument is typed as a Tensor
66     for (; names_it != param_names.end(); ++names_it) {
67       args.emplace_back(
68           /*name=*/*names_it,
69           /*type=*/TensorType::get(),
70           /*N=*/std::nullopt,
71           /*default_value=*/std::nullopt,
72           /*kwarg_only=*/false);
73     }
74 
75     // Use as many outputs as are requested to make the return type
76     TypePtr ret_type = TensorType::get();
77     if (n_binders == 0) {
78       ret_type = NoneType::get();
79     } else if (n_binders > 1) {
80       std::vector<TypePtr> tuple_values(n_binders, ret_type);
81       ret_type = TupleType::create(std::move(tuple_values));
82     }
83     rets.emplace_back(Argument("0", ret_type, {}, {}, false));
84   } else {
85     // Use the provided type signature
86     auto [arg_types, ret_type] =
87         py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
88 
89     // arg_types does not include self but param_names does, so adjust for that
90     // if needed
91     TORCH_INTERNAL_ASSERT(
92         arg_types.size() == param_names.size() - (moduleSelf_ ? 1 : 0));
93 
94     auto types_it = arg_types.begin();
95     for (; types_it != arg_types.end(); ++types_it, ++names_it) {
96       args.emplace_back(
97           /*name=*/*names_it,
98           /*type=*/std::move(*types_it),
99           /*N=*/std::nullopt,
100           /*default_value=*/std::nullopt,
101           /*kwarg_only=*/false);
102     }
103     rets.push_back(Argument("0", ret_type, {}, {}, false));
104   }
105 
106   std::string name;
107   if (py::hasattr(self, "__qualname__")) {
108     // Use the qualified name if possible
109     name = py::str(py::getattr(self, "__qualname__"));
110   } else if (py::hasattr(self, "__name__")) {
111     name = py::str(py::getattr(self, "__name__"));
112   }
113   return FunctionSchema(name, "", std::move(args), std::move(rets), is_vararg);
114 }
115 
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)116 std::shared_ptr<SugaredValue> PythonValue::call(
117     const SourceRange& loc,
118     GraphFunction& m,
119     at::ArrayRef<NamedValue> args,
120     at::ArrayRef<NamedValue> kwargs,
121     size_t n_binders) {
122   std::vector<NamedValue> argsWithSelf;
123   if (moduleSelf_) {
124     argsWithSelf.emplace_back("self", moduleSelf_);
125   }
126   argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end());
127 
128   auto schema = getSchema(argsWithSelf.size(), n_binders, loc);
129   auto inputs = toValues(*m.graph(), argsWithSelf);
130 
131   MatchedSchema matched_schema =
132       matchSchema(schema, loc, *m.graph(), argsWithSelf, kwargs);
133 
134   // If if a function is marked as dropped,
135   // we throw an exception if it is invoked.
136   if (py::cast<bool>(py::module::import("torch._jit_internal")
137                          .attr("should_drop")(self))) {
138     auto g = m.graph();
139     auto err_msg = insertConstant(
140         *g,
141         IValue(
142             "This Python function is annotated to be ignored and cannot be run"));
143     g->insert(prim::RaiseException, {err_msg}, {}, loc);
144     return std::make_shared<SimpleValue>(
145         g->insertNode(g->createUninitialized(matched_schema.return_types.at(0)))
146             ->output());
147   }
148 
149   // Release the function object so we can wrap it in a PythonOp
150   py::object func = self;
151   std::string cconv(inputs.size(), 'd');
152   Node* new_node = m.graph()->insertNode(
153       m.graph()->createPythonOp(THPObjectPtr(func.release().ptr()), cconv, {}));
154 
155   new_node->setSourceRange(loc);
156   for (auto& i : matched_schema.inputs)
157     new_node->addInput(i);
158 
159   Value* output =
160       new_node->addOutput()->setType(matched_schema.return_types.at(0));
161   return std::make_shared<SimpleValue>(output);
162 }
163 
kind() const164 std::string PythonValue::kind() const {
165   std::stringstream ss;
166   ss << "python value of type '" << typeString(self) << "'";
167   return ss.str();
168 }
169 
asTuple(const SourceRange & loc,GraphFunction & m,const std::optional<size_t> & size_hint)170 std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
171     const SourceRange& loc,
172     GraphFunction& m,
173     const std::optional<size_t>& size_hint) {
174   std::stringstream ss;
175   ss << kind() << " cannot be used as a tuple";
176   checkForAddToConstantsError(ss);
177   throw(ErrorReport(loc) << ss.str());
178 }
179 
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)180 std::shared_ptr<SugaredValue> PythonValue::attr(
181     const SourceRange& loc,
182     GraphFunction& m,
183     const std::string& field) {
184   std::stringstream ss;
185   ss << "attribute lookup is not defined on " << kind();
186   checkForAddToConstantsError(ss);
187   throw(ErrorReport(loc) << ss.str());
188 }
189 
getattr(const SourceRange & loc,const std::string & name)190 py::object PythonValue::getattr(
191     const SourceRange& loc,
192     const std::string& name) {
193   try {
194     return py::getattr(self, name.c_str());
195   } catch (py::error_already_set& e) {
196     throw(ErrorReport(loc) << "object has no attribute " << name);
197   }
198 }
199 
checkForAddToConstantsError(std::stringstream & ss)200 void PythonValue::checkForAddToConstantsError(std::stringstream& ss) {
201   auto nn = py::module::import("torch.nn");
202   if (py::isinstance(self, nn.attr("ModuleList")) ||
203       py::isinstance(self, nn.attr("Sequential"))) {
204     ss << ". Did you forget to add it to __constants__? ";
205   }
206 }
207 
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)208 std::shared_ptr<SugaredValue> PythonModuleValue::attr(
209     const SourceRange& loc,
210     GraphFunction& m,
211     const std::string& field) {
212   py::object member = getattr(loc, field);
213   // note: is_constant = true because we consider that global properties
214   // on modules like math.pi or torch.float to be constants
215   // even though it is possible, though rare, for someone to mutate them
216   return toSugaredValue(member, m, loc, /*is_constant=*/true);
217 }
218 
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)219 std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
220     const SourceRange& loc,
221     GraphFunction& m,
222     const std::string& field) {
223   // List of all the cuda operators which are supported in JIT
224   const std::unordered_set<std::string> cuda_ops = {
225       "current_stream",
226       "default_stream",
227       "current_device",
228       "_exchange_device",
229       "_maybe_exchange_device",
230       "set_device",
231       "device_index",
232       "device_count",
233       "set_stream",
234       "synchronize"};
235 
236   if (cuda_ops.find(field) != cuda_ops.end()) {
237     // Both current_device and set_device API's are a part of c10::cuda
238     // namespace. Hence, to resolve the conflict for jit, we append _ to both
239     // these APIs.
240     if (field == "current_device" || field == "set_device") {
241       return std::make_shared<BuiltinFunction>(
242           Symbol::cuda("_" + field), std::nullopt);
243     } else {
244       return std::make_shared<BuiltinFunction>(
245           Symbol::cuda(field), std::nullopt);
246     }
247   }
248 
249   if (field == "Stream" || field == "Event") {
250     auto class_type = getCustomClass("__torch__.torch.classes.cuda." + field);
251     return std::make_shared<ClassValue>(class_type);
252   }
253 
254   py::object member = getattr(loc, field);
255   // note: is_constant = true because we consider that global properties
256   // on modules like math.pi or torch.float to be constants
257   // even though it is possible, though rare, for someone to mutate them
258   return toSugaredValue(member, m, loc, /*is_constant=*/true);
259 }
260 
asValue(const SourceRange & loc,GraphFunction & m)261 Value* ModuleValue::asValue(const SourceRange& loc, GraphFunction& m) {
262   return self_;
263 }
264 
asTupleValue(const SourceRange & loc,GraphFunction & m)265 SugaredValuePtr ModuleValue::asTupleValue(
266     const SourceRange& loc,
267     GraphFunction& m) {
268   if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
269     auto dict = getSugaredDict(loc, m);
270     auto mods = dict->getModules();
271     return mods;
272   }
273   throw(
274       ErrorReport(loc)
275       << "Only ModuleList or Sequential modules can be used as tuple");
276 }
277 
areAllSubmodulesSubtypeOf(const TypePtr & ty,std::ostream * why_not) const278 bool ModuleValue::areAllSubmodulesSubtypeOf(
279     const TypePtr& ty,
280     std::ostream* why_not) const {
281   const auto& self_type = concreteType_->getJitType()->expect<ClassType>();
282   for (size_t i = 0; i < self_type->numAttributes(); ++i) {
283     const auto& attr_type = self_type->getAttribute(i);
284     if (attr_type->is_module()) {
285       std::stringstream ss;
286       if (!attr_type->isSubtypeOfExt(ty, &ss)) {
287         if (why_not) {
288           *why_not << "Attribute " << self_type->getAttributeName(i)
289                    << " is not of annotated type " << ty->annotation_str()
290                    << ": " << ss.str();
291         }
292 
293         return false;
294       }
295     }
296   }
297 
298   return true;
299 }
300 
getitem(const SourceRange & loc,GraphFunction & m,Value * idx,TypePtr type_hint)301 SugaredValuePtr ModuleValue::getitem(
302     const SourceRange& loc,
303     GraphFunction& m,
304     Value* idx,
305     TypePtr type_hint) {
306   if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
307     if (type_hint) {
308       // Check that all submodules comply with the type hint.
309       std::stringstream ss;
310       if (!areAllSubmodulesSubtypeOf(type_hint, &ss)) {
311         throw(ErrorReport(loc) << ss.str());
312       }
313 
314       // Emit a prim::ModuleContainerIndex operator. This is needed because
315       // it's difficult to construct a list in the graph representing the
316       // ModuleList and use aten::__getitem__ ops to index into it because
317       // any call to ModuleList.setitem would invalidate that emitted list.
318       auto graph = m.graph();
319       auto* getitem_node = graph->insertNode(
320           graph->create(prim::ModuleContainerIndex, {self_, idx}));
321       getitem_node->output(0)->setType(type_hint);
322       return std::make_shared<SimpleValue>(getitem_node->output(0));
323     } else {
324       return getSugaredDict(loc, m)->getModules()->getitem(
325           loc, m, idx, type_hint);
326     }
327   } else if (
328       concreteType_->getIterableModuleKind() == IterableModuleKind::PARAMLIST) {
329     return getSugaredNamedParameterList(loc, m)->getModules()->getitem(
330         loc, m, idx, type_hint);
331   } else if (
332       concreteType_->getIterableModuleKind() == IterableModuleKind::DICT ||
333       concreteType_->getIterableModuleKind() == IterableModuleKind::PARAMDICT) {
334     if (auto ivalue = toIValue(idx)) {
335       std::shared_ptr<SugaredDict> sd;
336       if (concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
337         sd = getSugaredDict(loc, m);
338       } else if (
339           concreteType_->getIterableModuleKind() ==
340           IterableModuleKind::PARAMDICT) {
341         sd = getSugaredNamedParameterDict(loc, m);
342       }
343       auto idx_str = ivalue->toStringRef();
344       auto keys_iter = sd->keys_;
345       auto module_values_iter = sd->modules_;
346       for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
347         auto key = keys_iter->tup_.at(i);
348         auto key_str = toIValue(key->asValue(loc, m))->toStringRef();
349         if (key_str == idx_str) {
350           return module_values_iter->tup_.at(i);
351         }
352       }
353       throw(ErrorReport(loc) << "Key Error, " << idx_str);
354     } else if (type_hint) {
355       // Check that all submodules comply with the type hint.
356       std::stringstream ss;
357       if (!areAllSubmodulesSubtypeOf(type_hint, &ss)) {
358         throw(ErrorReport(loc) << ss.str());
359       }
360 
361       // Emit a prim::ModuleContainerIndex operator. This is needed because
362       // it's difficult to construct a dict in the graph representing the
363       // ModuleDict and use aten::__getitem__ ops to index into it because
364       // any call to ModuleDict.setAttr would invalidate that emitted dict.
365       auto graph = m.graph();
366       auto* getitem_node = graph->insertNode(
367           graph->create(prim::ModuleContainerIndex, {self_, idx}));
368       getitem_node->output(0)->setType(type_hint);
369       return std::make_shared<SimpleValue>(getitem_node->output(0));
370     }
371     throw(
372         ErrorReport(loc)
373         << "Unable to extract string literal index. "
374         << "ModuleDict indexing is only supported with string literals. "
375         << "For example, 'i = \"a\"; self.layers[i](x)' will fail because i is not a literal. "
376         << "Enumeration of ModuleDict is supported, e.g. 'for k, v in self.items(): out = v(inp)'");
377   }
378   throw(
379       ErrorReport(loc)
380       << "Only ModuleList, Sequential, ModuleDict, "
381       << "ParameterList, and ParameterDict modules are subscriptable");
382 }
383 
checkInterface(const SourceRange & loc,GraphFunction & m,const std::shared_ptr<ModuleValue> & self,const std::string & field)384 void checkInterface(
385     const SourceRange& loc,
386     GraphFunction& m,
387     const std::shared_ptr<ModuleValue>& self,
388     const std::string& field) {
389   if (self->asValue(loc, m)->type()->cast<InterfaceType>()) {
390     throw(
391         ErrorReport(loc)
392         << "Could not compile " << field
393         << "() because module is an interface type. Please file issue.");
394   }
395 }
396 
recurseThroughNestedModules(const SourceRange & loc,GraphFunction & m,std::vector<SugaredValuePtr> & keys,std::vector<SugaredValuePtr> & values,std::shared_ptr<ModuleValue> & self,const std::string & prefix,const std::string & field)397 void recurseThroughNestedModules(
398     const SourceRange& loc,
399     GraphFunction& m,
400     std::vector<SugaredValuePtr>& keys,
401     std::vector<SugaredValuePtr>& values,
402     std::shared_ptr<ModuleValue>& self,
403     const std::string& prefix,
404     const std::string& field) {
405   auto prefix_value =
406       std::make_shared<SimpleValue>(insertConstant(*m.graph(), prefix));
407 
408   keys.push_back(prefix_value);
409   values.push_back(self);
410 
411   checkInterface(loc, m, self, field);
412   auto module_dict = self->getSugaredDict(loc, m);
413   auto keys_iter = module_dict->keys_;
414   auto module_values_iter = module_dict->modules_;
415   for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
416     std::shared_ptr<SugaredValue> module_sugared_value =
417         module_values_iter->tup_.at(i);
418     auto module_value =
419         std::dynamic_pointer_cast<ModuleValue>(module_sugared_value);
420 
421     auto keys_value = keys_iter->tup_.at(i);
422     auto key_string = toIValue(keys_value->asValue(loc, m))->toStringRef();
423     std::string submodule_prefix = prefix;
424     if (!prefix.empty()) {
425       submodule_prefix = prefix + ".";
426     }
427     submodule_prefix += key_string;
428     recurseThroughNestedModules(
429         loc, m, keys, values, module_value, submodule_prefix, field);
430   };
431 }
432 
getSugaredNamedBufferDict(const SourceRange & loc,GraphFunction & m)433 std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
434     const SourceRange& loc,
435     GraphFunction& m) {
436   std::vector<std::string> paramNames;
437   std::vector<SugaredValuePtr> values;
438 
439   const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
440   for (size_t i = 0; i < selfType->numAttributes(); ++i) {
441     if (selfType->is_buffer(i)) {
442       paramNames.push_back(selfType->getAttributeName(i));
443     }
444   }
445 
446   std::vector<SugaredValuePtr> keys;
447   for (const auto& name : paramNames) {
448     auto name_v =
449         std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
450     m.graph()->insertGetAttr(self_, name);
451     values.push_back(tryGetAttr(loc, m, name));
452     keys.push_back(name_v);
453   }
454 
455   return std::make_shared<SugaredDict>(
456       std::make_shared<ModuleValue>(self_, concreteType_),
457       std::make_shared<SugaredTupleValue>(keys),
458       std::make_shared<SugaredTupleValue>(values));
459 }
460 
getSugaredNamedParameterList(const SourceRange & loc,GraphFunction & m)461 std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedParameterList(
462     const SourceRange& loc,
463     GraphFunction& m) {
464   std::vector<std::string> paramNames;
465   std::vector<SugaredValuePtr> values;
466 
467   const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
468   for (size_t i = 0; i < selfType->numAttributes(); ++i) {
469     if (selfType->is_parameter(i)) {
470       paramNames.push_back(selfType->getAttributeName(i));
471     }
472   }
473 
474   std::vector<SugaredValuePtr> keys;
475   for (const auto& name : paramNames) {
476     auto name_v =
477         std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
478     m.graph()->insertGetAttr(self_, name);
479     values.push_back(tryGetAttr(loc, m, name));
480     keys.push_back(name_v);
481   }
482 
483   return std::make_shared<SugaredDict>(
484       std::make_shared<ModuleValue>(self_, concreteType_),
485       std::make_shared<SugaredTupleValue>(keys),
486       std::make_shared<SugaredTupleValue>(values));
487 }
488 
getSugaredDict(const SourceRange & loc,GraphFunction & m)489 std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
490     const SourceRange& loc,
491     GraphFunction& m) {
492   std::vector<std::string> submoduleNames;
493   const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
494   for (size_t i = 0; i < selfType->numAttributes(); ++i) {
495     const auto& attrType = selfType->getAttribute(i);
496     if (attrType->is_module()) {
497       submoduleNames.push_back(selfType->getAttributeName(i));
498     }
499   }
500 
501   std::vector<SugaredValuePtr> keys;
502   std::vector<SugaredValuePtr> values;
503   for (const auto& name : submoduleNames) {
504     auto name_v =
505         std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
506     Value* module_v = m.graph()->insertGetAttr(self_, name);
507     auto mod_v = std::make_shared<ModuleValue>(
508         module_v, concreteType_->findSubmoduleConcreteType(name));
509 
510     keys.push_back(name_v);
511     values.push_back(mod_v);
512   }
513 
514   return std::make_shared<SugaredDict>(
515       std::make_shared<ModuleValue>(self_, concreteType_),
516       std::make_shared<SugaredTupleValue>(keys),
517       std::make_shared<SugaredTupleValue>(values));
518 }
519 
getSugaredNamedParameterDict(const SourceRange & loc,GraphFunction & m)520 std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedParameterDict(
521     const SourceRange& loc,
522     GraphFunction& m) {
523   std::vector<std::string> paramNames;
524   const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
525   for (size_t i = 0; i < selfType->numAttributes(); ++i) {
526     if (selfType->is_parameter(i)) {
527       paramNames.push_back(selfType->getAttributeName(i));
528     }
529   }
530 
531   std::vector<SugaredValuePtr> keys;
532   std::vector<SugaredValuePtr> values;
533   for (const auto& name : paramNames) {
534     auto name_v =
535         std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
536     m.graph()->insertGetAttr(self_, name);
537     auto val = tryGetAttr(loc, m, name);
538     TORCH_INTERNAL_ASSERT(val != nullptr, "Could not find attribute ", name);
539     values.push_back(val);
540     keys.push_back(name_v);
541   }
542 
543   return std::make_shared<SugaredDict>(
544       std::make_shared<ModuleValue>(self_, concreteType_),
545       std::make_shared<SugaredTupleValue>(keys),
546       std::make_shared<SugaredTupleValue>(values));
547 }
548 
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)549 std::shared_ptr<SugaredValue> SugaredDict::attr(
550     const SourceRange& loc,
551     GraphFunction& m,
552     const std::string& field) {
553   // Recursive compilation does not maintain module aliasing,
554   // so we do not add uniqueness checks on
555   // "children"/"named_children"/"modules"/"named_modules"
556   checkInterface(loc, m, self_, field);
557   if (field == "keys") {
558     return std::make_shared<ModuleDictMethod>(keys_, "keys");
559   } else if (field == "values" || field == "children") {
560     return std::make_shared<ModuleDictMethod>(modules_, field);
561   } else if (
562       field == "items" || field == "named_children" ||
563       field == "named_buffers") {
564     auto iterator = std::make_shared<IterableTree>();
565     iterator->addChild(loc, m, keys_);
566     iterator->addChild(loc, m, modules_);
567     return std::make_shared<ModuleDictMethod>(iterator, field);
568   } else if (field == "named_modules" || field == "modules") {
569     std::vector<SugaredValuePtr> keys;
570     std::vector<SugaredValuePtr> values;
571     recurseThroughNestedModules(loc, m, keys, values, self_, "", field);
572     if (field == "modules") {
573       return std::make_shared<ModuleDictMethod>(
574           std::make_shared<SugaredTupleValue>(values), field);
575     } else {
576       auto iterator = std::make_shared<IterableTree>();
577       iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(keys));
578       iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(values));
579       return std::make_shared<ModuleDictMethod>(iterator, field);
580     }
581   }
582   TORCH_INTERNAL_ASSERT(false);
583 }
584 
createSugaredEnumClassFromObj(const py::object & obj,GraphFunction & m,const SourceRange & loc)585 std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
586     const py::object& obj,
587     GraphFunction& m,
588     const SourceRange& loc) {
589   auto annotation_type = py::module::import("torch.jit.annotations")
590                              .attr("try_ann_to_type")(obj, loc);
591   TORCH_INTERNAL_ASSERT(!annotation_type.is_none());
592   auto type = py::cast<TypePtr>(annotation_type);
593   auto enum_type = type->expect<EnumType>();
594   return std::make_shared<SugaredEnumClass>(enum_type);
595 }
596 
597 // helper function for instantiating a SugaredValue from an IValue
toSugaredValue(const IValue & v,GraphFunction & m,const SourceRange & loc)598 std::shared_ptr<SugaredValue> toSugaredValue(
599     const IValue& v,
600     GraphFunction& m,
601     const SourceRange& loc) {
602   if (v.isTuple()) {
603     auto tp = v.toTuple();
604     std::vector<Value*> values;
605     values.reserve(tp->elements().size());
606     for (const auto& e : tp->elements()) {
607       values.push_back(toSugaredValue(e, m, loc)->asValue(loc, m));
608     }
609     return toSimple(
610         m.graph()->insertNode(m.graph()->createTuple(values))->output());
611   } else {
612     return toSimple(m.graph()->insertConstant(v, loc));
613   }
614 }
615 
616 // This method controls how we desugar attribute lookups on ScriptModules
tryGetAttr(const SourceRange & loc,GraphFunction & m,const std::string & field)617 std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
618     const SourceRange& loc,
619     GraphFunction& m,
620     const std::string& field) {
621   // 1. Look inside Module object for the field.
622   const auto& selfType_ = concreteType_->getJitType();
623   if (selfType_->cast<InterfaceType>()) {
624     return std::make_shared<SimpleValue>(self_)->attr(loc, m, field);
625   }
626 
627   const auto& selfType = selfType_->expect<ClassType>();
628 
629   if (selfType->hasAttribute(field) &&
630       selfType->getAttribute(field)->is_module()) {
631     // ...if it's a submodule, return it as a new ModuleValue.
632     if (const auto submoduleConcreteType =
633             concreteType_->findSubmoduleConcreteType(field)) {
634       return std::make_shared<ModuleValue>(
635           m.graph()->insertGetAttr(self_, field), submoduleConcreteType);
636     }
637 
638     return std::make_shared<ModuleValue>(
639         m.graph()->insertGetAttr(self_, field),
640         ConcreteModuleType::fromJitType(selfType->getAttribute(field)));
641   } else if (selfType->hasAttribute(field) || selfType->findMethod(field)) {
642     // ...otherwise, methods, parameters, attributes, and buffers are all
643     // first class so they get returned as SimpleValues
644     return std::make_shared<SimpleValue>(self_)->attr(loc, m, field);
645   } else if (selfType->hasConstant(field)) {
646     auto v = selfType->getConstant(field);
647     return toSugaredValue(v, m, loc);
648   }
649 
650   // 2. Special case: for module dicts we manually desugar items(), keys(),
651   // values() calls into the appropriate method.
652   if (concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
653     if (field == "items" || field == "keys" || field == "values") {
654       return getSugaredDict(loc, m)->attr(loc, m, field);
655     }
656   }
657 
658   if (field == "named_modules" || field == "modules" || field == "children" ||
659       field == "named_children") {
660     return getSugaredDict(loc, m)->attr(loc, m, field);
661   }
662 
663   if (field == "named_buffers") {
664     return getSugaredNamedBufferDict(loc, m)->attr(loc, m, field);
665   }
666 
667   // 3. Check if this is the name of an overloaded method.
668 
669   // This can also be a call to a non-script module, or a plain
670   // python method. If so return this as a python value.
671   if (const auto overloads = concreteType_->findOverloads(field)) {
672     return std::make_shared<MethodValue>(self_, *overloads);
673   }
674 
675   // 4. Check if it's a function attribute.
676   if (const auto fnAttr = concreteType_->findFunctionAttribute(field)) {
677     return std::make_shared<FunctionValue>(*fnAttr);
678   } else if (const auto builtin = concreteType_->findBuiltinFunction(field)) {
679     return std::make_shared<BuiltinFunction>(*builtin, /*self=*/std::nullopt);
680   }
681 
682   // 5. Check if it's an attribute of the original Python class that this
683   // ScriptModule was derived from. The only class attributes we handle are
684   // methods.
685   const auto maybePyClass = concreteType_->getPyClass();
686   if (!maybePyClass) {
687     // ConcreteType doesn't always have an originating Python class, e.g. if it
688     // was derived from a serialized ScriptModule. In this case, we've exhausted
689     // our options for attr lookup.
690     return nullptr;
691   }
692   py::object unboundMethod = py::getattr(
693       *maybePyClass, field.c_str(), pybind11::cast<pybind11::none>(Py_None));
694 
695   if (py::isinstance<py::function>(unboundMethod)) {
696     bool isStaticFn =
697         py::cast<bool>(py::module::import("torch._jit_internal")
698                            .attr("is_static_fn")(*maybePyClass, field.c_str()));
699     if (isStaticFn) {
700       // Functions within the module annotated with @staticmethod do not need
701       // binding.
702       py::object staticFn =
703           py::module::import("torch._jit_internal")
704               .attr("get_static_fn")(*maybePyClass, field.c_str());
705       return toSugaredValue(staticFn, m, loc);
706     }
707     // For Python methods that we're trying to call directly, we need to bind
708     // the method to a self. (see the documentation for lazy_bind in Python for
709     // more info).
710     bool isIgnoredFn =
711         py::cast<bool>(py::module::import("torch._jit_internal")
712                            .attr("is_ignored_fn")(unboundMethod));
713     if (isIgnoredFn) {
714       // Create a generated ScriptModule type with module_ set as cpp_module
715       auto boundMethod = py::module::import("torch.jit._recursive")
716                              .attr("lazy_bind")(concreteType_, unboundMethod);
717       TORCH_CHECK(py::isinstance<py::function>(boundMethod));
718       auto rcb =
719           py::module::import("torch._jit_internal")
720               .attr("createResolutionCallbackFromClosure")(unboundMethod);
721       return std::make_shared<PythonValue>(boundMethod, rcb, self_);
722     }
723 
724     // If we reach here, it's because this is a "normal" method that just hasn't
725     // been compiled yet (directly exported methods would have been returned by
726     // step 1). Just compile it.
727     auto stub =
728         py::module::import("torch.jit._recursive")
729             .attr("compile_unbound_method")(concreteType_, unboundMethod);
730     TORCH_INTERNAL_ASSERT(!stub.is_none());
731     // Look up the attribute again, it will be available as a compiled method.
732     return attr(loc, m, field);
733   }
734 
735   return nullptr;
736 }
737 
hasAttr(const SourceRange & loc,GraphFunction & m,const std::string & field)738 bool ModuleValue::hasAttr(
739     const SourceRange& loc,
740     GraphFunction& m,
741     const std::string& field) {
742   return tryGetAttr(loc, m, field) != nullptr;
743 }
744 
call(const SourceRange & loc,GraphFunction & caller,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)745 std::shared_ptr<SugaredValue> ModuleValue::call(
746     const SourceRange& loc,
747     GraphFunction& caller,
748     at::ArrayRef<NamedValue> args,
749     at::ArrayRef<NamedValue> kwargs,
750     size_t n_binders) {
751   c10::ClassTypePtr class_type = concreteType_->getJitType()->cast<ClassType>();
752   bool have_pre_hooks = class_type && !class_type->getForwardPreHooks().empty();
753   bool have_hooks = class_type && !class_type->getForwardHooks().empty();
754 
755   std::vector<Value*> arg_values;
756   std::vector<NamedValue> pre_hook_result;
757   Value* forward_input = nullptr;
758   std::shared_ptr<Graph> calling_graph = caller.graph();
759 
760   if (have_pre_hooks || have_hooks) {
761     // convert forward args into tuple for forward hooks
762     // (the input of eager hooks are always tuples)
763     for (const auto& sv : args) {
764       arg_values.push_back(sv.value(*calling_graph));
765     }
766     forward_input =
767         calling_graph->insertNode(calling_graph->createTuple(arg_values))
768             ->output();
769   }
770 
771   // call pre_hooks
772   if (have_pre_hooks) {
773     for (const auto& hook : class_type->getForwardPreHooks()) {
774       TORCH_INTERNAL_ASSERT(forward_input != nullptr);
775       Value* pre_hook_output =
776           FunctionValue(hook)
777               .call(
778                   loc,
779                   caller,
780                   {NamedValue(self_), NamedValue(forward_input)},
781                   kwargs,
782                   n_binders)
783               ->asValue(loc, caller);
784       if (pre_hook_output->type() != NoneType::get()) {
785         if (pre_hook_output->type()->kind() != TypeKind::TupleType) {
786           pre_hook_output =
787               calling_graph
788                   ->insertNode(calling_graph->createTuple({pre_hook_output}))
789                   ->output();
790         }
791         forward_input = pre_hook_output;
792       }
793     }
794     // de-tuple pre_hook output for forward
795     at::ArrayRef<Value*> output_nodes =
796         calling_graph
797             ->insertNode(calling_graph->createTupleUnpack(forward_input))
798             ->outputs();
799     for (auto& output_node : output_nodes) {
800       pre_hook_result.emplace_back(output_node);
801     }
802     if (!args.empty()) { // only replace input if it existed
803       args = pre_hook_result;
804     }
805   }
806 
807   // call forward
808   std::shared_ptr<SugaredValue> forwardSV =
809       attr(loc, caller, "forward")->call(loc, caller, args, kwargs, n_binders);
810   Value* forward_output = forwardSV->asValue(loc, caller);
811 
812   // call hooks
813   if (have_hooks) {
814     for (const auto& hook : class_type->getForwardHooks()) {
815       Value* forward_hook_output = FunctionValue(hook)
816                                        .call(
817                                            loc,
818                                            caller,
819                                            {NamedValue(self_),
820                                             NamedValue(forward_input),
821                                             NamedValue(forward_output)},
822                                            kwargs,
823                                            n_binders)
824                                        ->asValue(loc, caller);
825       if (forward_hook_output->type() != NoneType::get()) {
826         forward_output = forward_hook_output;
827       }
828     }
829   }
830 
831   return std::make_shared<SimpleValue>(forward_output);
832 }
833 
834 // This method controls how we desugar attribute lookups on ScriptModules.
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)835 std::shared_ptr<SugaredValue> ModuleValue::attr(
836     const SourceRange& loc,
837     GraphFunction& m,
838     const std::string& field) {
839   if (auto attr = tryGetAttr(loc, m, field)) {
840     return attr;
841   }
842 
843   // Check if it's a property.
844   auto prop =
845       concreteType_->getJitType()->expectRef<ClassType>().getProperty(field);
846   if (prop) {
847     return MethodValue(self_, prop->getter->name())
848         .call(loc, m, {}, {}, /*n_binders=*/1);
849   }
850 
851   // We don't define this attr. Bailout with a hint to the user.
852   std::string hint;
853   if (auto failureReason = concreteType_->findFailedAttribute(field)) {
854     hint = *failureReason;
855   } else if (concreteType_->isIgnoredAttribute(field)) {
856     hint = "attribute was ignored during compilation";
857   }
858 
859   throw(
860       ErrorReport(loc)
861       << "Module '"
862       << concreteType_->getJitType()->expectRef<ClassType>().name()->name()
863       << "'"
864       << " has no attribute '" << field << "' " << hint);
865 }
866 
iter(const SourceRange & loc,GraphFunction & m)867 SugaredValuePtr ModuleValue::iter(const SourceRange& loc, GraphFunction& m) {
868   const auto iterableModuleKind = concreteType_->getIterableModuleKind();
869   if (iterableModuleKind == IterableModuleKind::NONE) {
870     throw(
871         ErrorReport(loc)
872         << "Only constant Sequential, ModuleList, ModuleDict, or "
873         << "ParameterList can be used as an iterable");
874   }
875 
876   if (iterableModuleKind == IterableModuleKind::DICT) {
877     auto module_dict = getSugaredDict(loc, m);
878     return module_dict->keys_;
879   } else if (iterableModuleKind == IterableModuleKind::LIST) {
880     auto module_dict = getSugaredDict(loc, m);
881     return module_dict->modules_;
882   } else if (iterableModuleKind == IterableModuleKind::PARAMLIST) {
883     auto module_dict = getSugaredNamedParameterList(loc, m);
884     return module_dict->modules_;
885   } else {
886     TORCH_INTERNAL_ASSERT(false);
887   }
888 }
889 
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)890 std::shared_ptr<SugaredValue> PythonClassValue::attr(
891     const SourceRange& loc,
892     GraphFunction& m,
893     const std::string& field) {
894   // Resolve values from the Python object first (e.g. for static methods on
895   // this type, resolve them as functions)
896   if (auto* fn = type_->findStaticMethod(field)) {
897     return std::make_shared<FunctionValue>(fn);
898   }
899   auto py_attr = py::getattr(py_type_, field.c_str(), py::none());
900   if (!py_attr.is_none()) {
901     return toSugaredValue(py_attr, m, loc);
902   }
903 
904   return ClassValue::attr(loc, m, field);
905 }
906 
hasAttr(const SourceRange & loc,GraphFunction & m,const std::string & field)907 bool PythonClassValue::hasAttr(
908     const SourceRange& loc,
909     GraphFunction& m,
910     const std::string& field) {
911   try {
912     py::getattr(py_type_, field.c_str());
913     return true;
914   } catch (py::error_already_set& e) {
915     return false;
916   }
917 }
918 
setAttr(const SourceRange & loc,GraphFunction & m,const std::string & field,Value * newValue)919 void ModuleValue::setAttr(
920     const SourceRange& loc,
921     GraphFunction& m,
922     const std::string& field,
923     Value* newValue) {
924   // Forward to SimpleValue::setAttr
925   SimpleValue simple(self_);
926   simple.setAttr(loc, m, field, newValue);
927 }
928 
call(const SourceRange & loc,GraphFunction & caller,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)929 std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
930     const SourceRange& loc,
931     GraphFunction& caller,
932     at::ArrayRef<NamedValue> args,
933     at::ArrayRef<NamedValue> kwargs,
934     size_t n_binders) {
935   std::optional<bool> result;
936   Graph& graph = *(caller.graph());
937 
938   auto index = py::cast<size_t>(dispatched_fn_["index"]);
939   auto arg_name = py::str(dispatched_fn_["arg_name"]);
940 
941   ErrorReport error(loc);
942   if (index < args.size()) {
943     // Dispatch flag is in arg list
944     result = constant_as<bool>(args.at(index).value(graph));
945     error << "Argument for boolean dispatch at position " << index
946           << " was not constant";
947   } else if (auto i = findInputWithName(arg_name, kwargs)) {
948     // Dispatch flag is in kwargs
949     result = constant_as<bool>(kwargs[*i].value(graph));
950     error << "Keyword argument '" << arg_name
951           << "' for boolean dispatch at position was not constant";
952   } else {
953     // Didn't find dispatch flag, so use default value
954     result = py::cast<bool>(dispatched_fn_["default"]);
955     TORCH_INTERNAL_ASSERT(result);
956   }
957 
958   if (!result.has_value()) {
959     throw ErrorReport(error);
960   }
961 
962   std::shared_ptr<SugaredValue> value;
963   if (*result) {
964     value = toSugaredValue(dispatched_fn_["if_true"], caller, loc);
965   } else {
966     value = toSugaredValue(dispatched_fn_["if_false"], caller, loc);
967   }
968   return value->call(loc, caller, args, kwargs, n_binders);
969 }
970 
call(const SourceRange & loc,GraphFunction & caller,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t)971 std::shared_ptr<SugaredValue> PythonExceptionValue::call(
972     const SourceRange& loc,
973     GraphFunction& caller,
974     at::ArrayRef<NamedValue> args,
975     at::ArrayRef<NamedValue> kwargs,
976     size_t /*n_binders*/) {
977   Value* error_message = nullptr;
978   if (args.empty()) {
979     error_message = insertConstant(*caller.graph(), "", loc);
980   } else if (args.size() == 1) {
981     error_message = args.at(0).value(*caller.graph());
982   } else {
983     std::vector<Value*> message_values;
984     message_values.reserve(args.size() + kwargs.size());
985 
986     for (const auto& inp : args) {
987       message_values.push_back(inp.value(*caller.graph()));
988     }
989     for (const auto& kwarg_inp : kwargs) {
990       message_values.push_back(kwarg_inp.value(*caller.graph()));
991     }
992     error_message =
993         caller.graph()
994             ->insertNode(caller.graph()->createTuple(message_values))
995             ->output();
996   }
997   Value* qualified_class_name =
998       insertConstant(*caller.graph(), exception_class_qualified_name_, loc);
999 
1000   return std::make_shared<ExceptionMessageValue>(
1001       error_message, qualified_class_name);
1002 }
1003 
isNamedTupleClass(const py::object & obj)1004 bool isNamedTupleClass(const py::object& obj) {
1005   auto tuple_type = reinterpret_cast<PyObject*>(&PyTuple_Type);
1006   int is_tuple_class = PyObject_IsSubclass(obj.ptr(), tuple_type);
1007   if (is_tuple_class == -1) {
1008     PyErr_Clear();
1009     return false;
1010   }
1011   return is_tuple_class == 1 && py::hasattr(obj, "_fields");
1012 }
1013 
registerNamedTuple(const py::object & obj,const SourceRange & loc,const ResolutionCallback & rcb)1014 TypePtr registerNamedTuple(
1015     const py::object& obj,
1016     const SourceRange& loc,
1017     const ResolutionCallback& rcb) {
1018   TORCH_INTERNAL_ASSERT(isNamedTupleClass(obj));
1019   auto qualifiedName = c10::QualifiedName(py::cast<std::string>(
1020       py::module::import("torch._jit_internal").attr("_qualified_name")(obj)));
1021 
1022   // Note: we need to pass rcb to resolve ForwardRef annotations. See
1023   // [Note: ForwardRef annotations in NamedTuple attributes]
1024   py::object props =
1025       py::module::import("torch._jit_internal")
1026           .attr("_get_named_tuple_properties")(obj, loc, py::cpp_function(rcb));
1027 
1028   auto [unqualName, field_names, field_types, objects] = py::cast<std::tuple<
1029       std::string,
1030       std::vector<std::string>,
1031       std::vector<TypePtr>,
1032       std::vector<py::object>>>(props);
1033 
1034   std::vector<IValue> field_defaults;
1035   auto min_default_idx = field_names.size() - objects.size();
1036   for (size_t i = min_default_idx, j = 0; i < field_names.size(); ++i, ++j) {
1037     py::object o = objects[j];
1038     auto type = tryToInferType(objects[j]);
1039     IValue ival = toIValue(objects[j], type.type());
1040     TORCH_CHECK(
1041         ival.tagKind() != "Tensor",
1042         "Tensors are"
1043         " not supported as default NamedTuple fields. Their "
1044         "mutability could lead to potential memory aliasing "
1045         "problems");
1046     field_defaults.emplace_back(ival);
1047   }
1048 
1049   auto tt = TupleType::createNamed(
1050       qualifiedName, field_names, field_types, field_defaults);
1051   if (auto type = get_python_cu()->get_type(qualifiedName)) {
1052     TORCH_CHECK(
1053         type->isSubtypeOf(tt), "Can't redefine NamedTuple: ", tt->repr_str());
1054     return type;
1055   }
1056   get_python_cu()->register_type(tt);
1057   return tt;
1058 }
1059 
isEnumClass(py::object obj)1060 bool isEnumClass(py::object obj) {
1061   auto enum_type_obj =
1062       py::cast<py::object>(py::module::import("enum").attr("Enum"));
1063   int ret = PyObject_IsSubclass(obj.ptr(), enum_type_obj.ptr());
1064   if (ret == -1) {
1065     PyErr_Clear();
1066     return false;
1067   }
1068   return ret == 1;
1069 }
1070 
createSimpleEnumValue(const py::object & obj,GraphFunction & m,const SourceRange & loc)1071 std::shared_ptr<SugaredValue> createSimpleEnumValue(
1072     const py::object& obj,
1073     GraphFunction& m,
1074     const SourceRange& loc) {
1075   auto enum_class = obj.attr("__class__");
1076   auto enum_type =
1077       py::cast<TypePtr>(py::module::import("torch.jit.annotations")
1078                             .attr("try_ann_to_type")(enum_class, loc));
1079   auto enum_ivalue = toIValue(obj, enum_type);
1080   return toSimple(m.graph()->insertConstant(enum_ivalue, loc));
1081 }
1082 
call(const SourceRange & loc,GraphFunction & caller,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t)1083 std::shared_ptr<SugaredValue> PythonSliceClass::call(
1084     const SourceRange& loc,
1085     GraphFunction& caller,
1086     at::ArrayRef<NamedValue> args,
1087     at::ArrayRef<NamedValue> kwargs,
1088     size_t /*n_binders*/) {
1089   if (!kwargs.empty()) {
1090     throw(ErrorReport(loc) << "Slice does not accept any keyword arguments");
1091   }
1092 
1093   static constexpr int64_t default_start = 0;
1094   static constexpr int64_t default_stop = std::numeric_limits<int64_t>::max();
1095   static constexpr int64_t default_step = 1;
1096   Graph& graph = *(caller.graph());
1097 
1098   auto ValOr = [&](Value* given, int64_t default_val) {
1099     if (!given || given->type()->isSubtypeOf(*NoneType::get())) {
1100       return graph.insertConstant(default_val, loc);
1101     }
1102     return given;
1103   };
1104 
1105   Value* start = nullptr;
1106   Value* stop = nullptr;
1107   Value* step = nullptr;
1108   size_t n = args.size();
1109   // Slice's constructor signature is Slice(start=None, stop, step=None)
1110   if (n == 1) {
1111     // Case where only `stop` is specified.
1112     start = ValOr(nullptr, default_start);
1113     stop = ValOr(args[0].value(graph), default_stop);
1114     step = ValOr(nullptr, default_step);
1115   } else if (n == 2) {
1116     // Case where `start` and `stop` are specified.
1117     start = ValOr(args[0].value(graph), default_start);
1118     stop = ValOr(args[1].value(graph), default_stop);
1119     step = ValOr(nullptr, default_step);
1120   } else if (n == 3) {
1121     // Case where `start`, `stop` and `step` are all specified.
1122     start = ValOr(args[0].value(graph), default_start);
1123     stop = ValOr(args[1].value(graph), default_stop);
1124     step = ValOr(args[2].value(graph), default_step);
1125   } else {
1126     throw(
1127         ErrorReport(loc) << "slice accepts exactly 1, 2 or 3 arguments, got: "
1128                          << n);
1129   }
1130 
1131   return std::make_shared<SliceValue>(start, stop, step);
1132 }
1133 
toSugaredValue(py::object obj,GraphFunction & m,const SourceRange & loc,bool is_constant)1134 std::shared_ptr<SugaredValue> toSugaredValue(
1135     py::object obj,
1136     GraphFunction& m,
1137     const SourceRange& loc,
1138     bool is_constant) {
1139   // directly create SimpleValues when possible, because they are first-class
1140   // and can be re-assigned. Otherwise, this would be invalid:
1141   // f = python_constant
1142   // while ...
1143   //   f = f + 1
1144   auto& g = *m.graph();
1145   if (is_constant) {
1146     if (py::isinstance<py::bool_>(obj)) {
1147       return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
1148     } else if (py::isinstance<py::int_>(obj)) {
1149       return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
1150     } else if (py::isinstance<py::float_>(obj)) {
1151       return toSimple(g.insertConstant(py::cast<double>(obj), loc));
1152     } else if (PyComplex_CheckExact(obj.ptr())) {
1153       auto c_obj = py::cast<std::complex<double>>(obj.ptr());
1154       return toSimple(
1155           g.insertConstant(static_cast<c10::complex<double>>(c_obj), loc));
1156     } else if (py::isinstance<py::str>(obj)) {
1157       return toSimple(g.insertConstant(py::cast<std::string>(obj), loc));
1158     } else if (obj.is_none()) {
1159       return toSimple(g.insertConstant(IValue(), loc));
1160     } else if (THPDevice_Check(obj.ptr())) {
1161       auto device = reinterpret_cast<THPDevice*>(obj.ptr());
1162       return toSimple(g.insertConstant(device->device));
1163     } else if (THPLayout_Check(obj.ptr())) {
1164       auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
1165       const auto v = static_cast<int64_t>(layout->layout);
1166       return toSimple(g.insertConstant(v, loc));
1167     } else if (THPMemoryFormat_Check(obj.ptr())) {
1168       auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
1169       const auto v = static_cast<int64_t>(memory_format->memory_format);
1170       return toSimple(g.insertConstant(v, loc));
1171     } else if (THPDtype_Check(obj.ptr())) {
1172       auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
1173       const auto v = static_cast<int64_t>(dtype->scalar_type);
1174       return toSimple(g.insertConstant(v, loc));
1175     } else if (THPQScheme_Check(obj.ptr())) {
1176       auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
1177       const auto v = static_cast<uint8_t>(qscheme->qscheme);
1178       return toSimple(g.insertConstant(v, loc));
1179     } else if (py::isinstance<py::tuple>(obj)) {
1180       py::tuple tup = obj;
1181       std::vector<Value*> values;
1182       values.reserve(tup.size());
1183       for (py::handle t : tup) {
1184         py::object obj = py::reinterpret_borrow<py::object>(t);
1185         values.push_back(toSugaredValue(obj, m, loc, true)->asValue(loc, m));
1186       }
1187       return toSimple(
1188           m.graph()->insertNode(m.graph()->createTuple(values))->output());
1189     }
1190   }
1191 
1192   auto opoverloadpacket_type =
1193       py::module::import("torch").attr("_ops").attr("OpOverloadPacket");
1194   py::bool_ is_overloadpacket = py::isinstance(obj, opoverloadpacket_type);
1195   if (is_overloadpacket) {
1196     obj = py::getattr(obj, "op");
1197   }
1198 
1199 #ifdef USE_RPC
1200   bool isRpcAvailable = py::cast<bool>(
1201       py::module::import("torch.distributed.rpc").attr("is_available")());
1202 #endif
1203 
1204   if (auto callee = as_function(obj)) {
1205     return std::make_shared<FunctionValue>(callee->function_);
1206   } else if (py::isinstance<py::module>(obj)) {
1207     std::string obj_name = py::cast<py::str>(py::getattr(obj, "__name__"));
1208     if (obj_name == "torch.cuda") {
1209       return std::make_shared<CUDAPythonModuleValue>(obj);
1210     }
1211     return std::make_shared<PythonModuleValue>(obj);
1212   } else if (
1213       obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr() ||
1214       obj.ptr() == py::module::import("torch.jit").attr("fork").ptr()) {
1215     return SpecialFormValue::create(prim::fork);
1216   } else if (
1217       obj.ptr() == py::module::import("torch.jit").attr("_awaitable").ptr()) {
1218     return SpecialFormValue::create(prim::awaitable);
1219   } else if (
1220       obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) {
1221     return SpecialFormValue::create(prim::annotate);
1222   } else if (
1223       obj.ptr() == py::module::import("torch.jit").attr("isinstance").ptr()) {
1224     return SpecialFormValue::create(prim::isinstance);
1225 #ifdef USE_RPC
1226     // RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on.
1227   } else if (
1228       isRpcAvailable &&
1229       obj.ptr() ==
1230           py::module::import("torch.distributed.rpc").attr("rpc_async").ptr()) {
1231     return SpecialFormValue::create(prim::rpc_async);
1232   } else if (
1233       isRpcAvailable &&
1234       obj.ptr() ==
1235           py::module::import("torch.distributed.rpc").attr("rpc_sync").ptr()) {
1236     return SpecialFormValue::create(prim::rpc_sync);
1237   } else if (
1238       isRpcAvailable &&
1239       // RPC module is only avaialble  when build flag "USE_DISTRIBUTED" is on.
1240       obj.ptr() ==
1241           py::module::import("torch.distributed.rpc").attr("remote").ptr()) {
1242     return SpecialFormValue::create(prim::rpc_remote);
1243 #endif
1244   } else if (auto callee = as_module(obj)) {
1245     throw(
1246         ErrorReport(loc) << "Cannot call a ScriptModule that is not"
1247                          << " a submodule of the caller");
1248   }
1249   std::vector<std::pair<const char*, at::ScalarType>> tensor_names = {
1250       {"BoolTensor", at::ScalarType::Bool},
1251       {"LongTensor", at::ScalarType::Long},
1252       {"ByteTensor", at::ScalarType::Byte},
1253       {"CharTensor", at::ScalarType::Char},
1254       {"DoubleTensor", at::ScalarType::Double},
1255       {"FloatTensor", at::ScalarType::Float},
1256       {"IntTensor", at::ScalarType::Int},
1257       {"ShortTensor", at::ScalarType::Short},
1258       {"HalfTensor", at::ScalarType::Half},
1259   };
1260   for (const auto& name : tensor_names) {
1261     if (obj.ptr() == py::module::import("torch").attr(name.first).ptr()) {
1262       // torch.LongTensor and other related functions create on cpu,
1263       // TODO: add support for torch.cuda.LongTensor for gpu
1264       return LegacyTensorConstructor::create(
1265           prim::LegacyTypedConstructor, name.second, at::kCPU);
1266     }
1267   }
1268 
1269   py::object builtin_name =
1270       py::module::import("torch.jit._builtins").attr("_find_builtin")(obj);
1271   if (!builtin_name.is_none()) {
1272     return std::make_shared<BuiltinFunction>(
1273         Symbol::fromQualString(py::str(builtin_name)), std::nullopt);
1274   }
1275 
1276   if (py::cast<bool>(py::module::import("torch._jit_internal")
1277                          .attr("_is_exception")(obj))) {
1278     return std::make_shared<PythonExceptionValue>(obj);
1279   }
1280 
1281   if (py::isinstance<py::function>(obj)) {
1282     if (typeString(obj) == "builtin_function_or_method") {
1283       throw(
1284           ErrorReport(loc) << "Python builtin " << py::str(obj)
1285                            << " is currently not supported in Torchscript");
1286     }
1287   }
1288 
1289   py::object dispatched_fn = py::module::import("torch._jit_internal")
1290                                  .attr("_try_get_dispatched_fn")(obj);
1291   if (!dispatched_fn.is_none()) {
1292     return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
1293   }
1294 
1295   if (py::isinstance<ScriptClass>(obj)) {
1296     auto script_class = py::cast<ScriptClass>(obj);
1297     return std::make_shared<PythonClassValue>(
1298         script_class.class_type_.type_->expect<ClassType>(), obj);
1299   }
1300 
1301   if (isNamedTupleClass(obj)) {
1302     // The use of fakeRcb here prevents us from correctly resolving ForwardRef
1303     // annotations on NamedTuple attributes for instances whose types are
1304     // inferred. See #95858 for more details, as well as
1305     // [Note: ForwardRef annotations in NamedTuple attributes]
1306     auto fakeRcb =
1307         py::module::import("torch.jit.annotations").attr("_fake_rcb");
1308     auto tuple_type =
1309         registerNamedTuple(obj, loc, fakeRcb)->expect<TupleType>();
1310     return std::make_shared<NamedTupleConstructor>(tuple_type);
1311   }
1312 
1313   if (isEnumClass(obj)) {
1314     return createSugaredEnumClassFromObj(obj, m, loc);
1315   }
1316 
1317   auto enum_type = py::module::import("enum").attr("Enum");
1318   py::bool_ is_enum_value = py::isinstance(obj, enum_type);
1319   if (py::cast<bool>(is_enum_value)) {
1320     return createSimpleEnumValue(obj, m, loc);
1321   }
1322 
1323   py::bool_ is_class = py::module::import("inspect").attr("isclass")(obj);
1324   if (py::cast<bool>(is_class)) {
1325     py::str qualifiedName =
1326         py::module::import("torch._jit_internal").attr("_qualified_name")(obj);
1327     auto pyCu = get_python_cu();
1328     auto qualname = c10::QualifiedName(qualifiedName);
1329 
1330     if (auto classType = pyCu->get_class(qualname)) {
1331       return std::make_shared<PythonClassValue>(classType, obj);
1332     } else {
1333       // If we can't get the source code for the type, it's implemented in C and
1334       // probably part of the standard library, so give up and leave it as a
1335       // call to Python
1336       bool can_compile_class =
1337           py::cast<bool>(py::module::import("torch._jit_internal")
1338                              .attr("can_compile_class")(obj));
1339       if (can_compile_class) {
1340         // Register class
1341         auto rcb = py::module::import("torch._jit_internal")
1342                        .attr("createResolutionCallbackForClassMethods")(obj);
1343         py::module::import("torch.jit._script")
1344             .attr("_recursive_compile_class")(obj, loc);
1345 
1346         // Return class
1347         auto newClassType = pyCu->get_class(qualname);
1348         AT_ASSERT(
1349             newClassType,
1350             "Class '",
1351             qualifiedName,
1352             "' should have been compiled but was not");
1353         return std::make_shared<PythonClassValue>(newClassType, obj);
1354       }
1355     }
1356   }
1357 
1358   py::bool_ isFunction = py::module::import("inspect").attr("isfunction")(obj);
1359   if (py::cast<bool>(isFunction)) {
1360     auto overloads =
1361         py::module::import("torch.jit._script").attr("_get_overloads")(obj);
1362     if (!overloads.is_none()) {
1363       auto compiled_fns = py::cast<std::vector<StrongFunctionPtr>>(overloads);
1364       return std::make_shared<FunctionValue>(std::move(compiled_fns));
1365     }
1366 
1367     auto compiled_fn = py::module::import("torch.jit._recursive")
1368                            .attr("try_compile_fn")(obj, loc);
1369     if (auto callee = as_function(compiled_fn)) {
1370       return std::make_shared<FunctionValue>(*callee);
1371     }
1372   }
1373   if (obj.ptr() == py::module::import("math").attr("inf").ptr()) {
1374     return toSimple(
1375         g.insertConstant(std::numeric_limits<double>::infinity(), loc));
1376   }
1377 
1378   py::bool_ isMethod = py::module::import("inspect").attr("ismethod")(obj);
1379   // methods here have been explicitly annotated to not be compiled,
1380   // so they do not have the same overload and compile checks as for functions
1381   if (isFunction || isMethod) {
1382     auto rcb = py::module::import("torch._jit_internal")
1383                    .attr("createResolutionCallbackFromClosure")(obj);
1384     return std::make_shared<PythonValue>(obj, rcb);
1385   }
1386 
1387   if (obj.is(py::module::import("builtins").attr("slice"))) {
1388     return std::make_shared<PythonSliceClass>();
1389   }
1390 
1391   return std::make_shared<PythonValue>(obj);
1392 }
1393 } // namespace torch::jit
1394