1 #include <pybind11/detail/common.h>
2 #include <pybind11/pytypes.h>
3 #include <torch/csrc/jit/api/object.h>
4 #include <torch/csrc/jit/python/script_init.h>
5 #include <torch/csrc/utils/pybind.h>
6
7 #include <caffe2/serialize/versions.h>
8 #include <torch/csrc/Device.h>
9 #include <torch/csrc/DynamicTypes.h>
10 #include <torch/csrc/jit/api/module.h>
11 #include <torch/csrc/jit/frontend/ir_emitter.h>
12 #include <torch/csrc/jit/frontend/sugared_value.h>
13 #include <torch/csrc/jit/mobile/code.h>
14 #include <torch/csrc/jit/mobile/compatibility/backport.h>
15 #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
16 #include <torch/csrc/jit/mobile/file_format.h>
17 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
18 #include <torch/csrc/jit/mobile/import.h>
19 #include <torch/csrc/jit/mobile/module.h>
20 #include <torch/csrc/jit/mobile/quantization.h>
21 #include <torch/csrc/jit/operator_upgraders/upgraders.h>
22 #include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
23 #include <torch/csrc/jit/operator_upgraders/utils.h>
24 #include <torch/csrc/jit/operator_upgraders/version_map.h>
25 #include <torch/csrc/jit/python/module_python.h>
26 #include <torch/csrc/jit/python/python_ivalue.h>
27 #include <torch/csrc/jit/python/python_sugared_value.h>
28 #include <torch/csrc/jit/serialization/export_bytecode.h>
29 #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
30 #include <torch/csrc/jit/serialization/import.h>
31 #include <torch/csrc/jit/testing/file_check.h>
32
33 #include <c10/util/Exception.h>
34 #include <c10/util/intrusive_ptr.h>
35 #include <c10/util/irange.h>
36 #include <torch/csrc/jit/frontend/parser.h>
37 #include <torch/csrc/jit/frontend/tracer.h>
38 #include <torch/csrc/jit/ir/constants.h>
39 #include <torch/csrc/jit/ir/graph_utils.h>
40 #include <torch/csrc/jit/ir/irparser.h>
41 #include <torch/csrc/jit/passes/inliner.h>
42 #include <torch/csrc/jit/passes/shape_analysis.h>
43 #include <torch/csrc/jit/python/pybind_utils.h>
44 #include <torch/csrc/jit/python/python_dict.h>
45 #include <torch/csrc/jit/python/python_list.h>
46 #include <torch/csrc/jit/python/python_tracer.h>
47 #include <torch/csrc/jit/runtime/graph_executor.h>
48 #include <torch/csrc/jit/runtime/instruction.h>
49 #include <torch/csrc/jit/runtime/interpreter.h>
50 #include <torch/csrc/jit/runtime/logging.h>
51 #include <torch/csrc/jit/serialization/export_bytecode.h>
52 #include <torch/csrc/jit/serialization/import_source.h>
53 #include <torch/csrc/jit/serialization/pickle.h>
54 #include <torch/csrc/jit/serialization/python_print.h>
55 #include <torch/csrc/jit/testing/hooks_for_testing.h>
56
57 #include <torch/csrc/api/include/torch/ordered_dict.h>
58
59 #include <ATen/ATen.h>
60 #include <ATen/core/function_schema.h>
61 #include <ATen/core/ivalue.h>
62 #include <ATen/core/qualified_name.h>
63
64 #include <pybind11/functional.h>
65 #include <pybind11/pybind11.h>
66 #include <pybind11/stl.h>
67 #include <pybind11/stl_bind.h>
68 #include <torch/csrc/jit/mobile/train/export_data.h>
69 #include <cstddef>
70 #include <memory>
71 #include <sstream>
72 #include <string>
73 #include <tuple>
74 #include <utility>
75 #include <vector>
76
77 #include <fmt/format.h>
78
79 namespace torch::jit {
80
using ::c10::Argument;
using ::c10::FunctionSchema;

// Maps a parameter name to the Python object supplied as its default value
// (looked up by argument name in getSchemaWithNameAndDefaults).
using FunctionDefaults = std::unordered_map<std::string, py::object>;
// Per-name collection of FunctionDefaults; presumably keyed by method name
// (its users are not visible in this chunk -- TODO confirm).
using ClassMethodDefaults = std::unordered_map<std::string, FunctionDefaults>;
86
87 namespace {
88
89 // A resolver that will inspect the outer Python scope to find `name`.
90 struct PythonResolver : public Resolver {
PythonResolvertorch::jit::__anon8636bec80111::PythonResolver91 explicit PythonResolver(ResolutionCallback rcb) : rcb_(std::move(rcb)) {}
92
93 /**
94 * While compiling classes, the class type we're compiling will not be
95 * available in Python, since we haven't fowner_ defining the class yet. So
96 * in order to make the class type available to its own methods, we need to
97 * explicitly resolve it.
98 *
99 * @param rcb Python function to resolve a name to its Python object in the
100 * enclosing scope
101 * @param classname The unqualified classname of the class currently being
102 * compiled.
103 * @param classType The class's type.
104 */
PythonResolvertorch::jit::__anon8636bec80111::PythonResolver105 explicit PythonResolver(
106 ResolutionCallback rcb,
107 std::string classname,
108 ClassTypePtr classType)
109 : rcb_(std::move(rcb)),
110 classname_(std::move(classname)),
111 classType_(std::move(classType)) {}
112
resolveValuetorch::jit::__anon8636bec80111::PythonResolver113 std::shared_ptr<SugaredValue> resolveValue(
114 const std::string& name,
115 GraphFunction& m,
116 const SourceRange& loc) override {
117 pybind11::gil_scoped_acquire ag;
118 py::object obj = rcb_(name);
119 if (obj.is_none()) {
120 return nullptr;
121 }
122 return toSugaredValue(obj, m, loc);
123 }
124
isNamedTupleClasstorch::jit::__anon8636bec80111::PythonResolver125 static bool isNamedTupleClass(py::object obj) {
126 auto tuple_type = reinterpret_cast<PyObject*>(&PyTuple_Type);
127 return PyObject_IsSubclass(obj.ptr(), tuple_type) &&
128 py::hasattr(obj, "_fields");
129 }
130
resolveTypeFromObjecttorch::jit::__anon8636bec80111::PythonResolver131 TypePtr resolveTypeFromObject(const py::object& obj, const SourceRange& loc) {
132 if (py::isinstance<ScriptClass>(obj)) {
133 auto script_class = py::cast<ScriptClass>(obj);
134 return script_class.class_type_.type_;
135 }
136
137 py::bool_ isClass = py::module::import("inspect").attr("isclass")(obj);
138 if (!py::cast<bool>(isClass)) {
139 return nullptr;
140 }
141
142 if (isNamedTupleClass(obj)) {
143 return registerNamedTuple(obj, loc, rcb_);
144 }
145
146 auto qualifiedName = c10::QualifiedName(
147 py::cast<std::string>(py::module::import("torch._jit_internal")
148 .attr("_qualified_name")(obj)));
149
150 return get_python_cu()->get_type(qualifiedName);
151 }
152
resolveTypetorch::jit::__anon8636bec80111::PythonResolver153 TypePtr resolveType(const std::string& name, const SourceRange& loc)
154 override {
155 if (classType_ && name == classname_) {
156 return classType_;
157 }
158 pybind11::gil_scoped_acquire ag;
159 py::object obj = rcb_(name);
160 if (obj.is_none()) {
161 return nullptr;
162 }
163
164 auto annotation_type =
165 py::module::import("torch.jit.annotations")
166 .attr("try_ann_to_type")(obj, loc, py::cpp_function(rcb_));
167 if (!annotation_type.is_none()) {
168 return py::cast<TypePtr>(annotation_type);
169 }
170 return resolveTypeFromObject(obj, loc);
171 }
172
173 private:
174 ResolutionCallback rcb_;
175 std::string classname_;
176 ClassTypePtr classType_;
177 };
178
pythonResolver(const ResolutionCallback & rcb)179 std::shared_ptr<PythonResolver> pythonResolver(const ResolutionCallback& rcb) {
180 return std::make_shared<PythonResolver>(rcb);
181 }
pythonResolver(const ResolutionCallback & rcb,std::string classname,ClassTypePtr classType)182 std::shared_ptr<PythonResolver> pythonResolver(
183 const ResolutionCallback& rcb,
184 std::string classname,
185 ClassTypePtr classType) {
186 return std::make_shared<PythonResolver>(
187 rcb, std::move(classname), std::move(classType));
188 }
189
checkOverloadDecl(const Decl & new_decl,const Decl & old_decl)190 void checkOverloadDecl(const Decl& new_decl, const Decl& old_decl) {
191 const auto& new_params = new_decl.params();
192 const auto& old_params = old_decl.params();
193
194 // TODO. same number of parameters not strictly necessary.
195 TORCH_INTERNAL_ASSERT(
196 new_params.size() == old_params.size(),
197 "Overload must have same number of parameters\n",
198 new_decl.range(),
199 old_decl.range());
200 for (const auto i : c10::irange(new_decl.params().size())) {
201 TORCH_INTERNAL_ASSERT(
202 new_params[i].ident().name() == old_params[i].ident().name(),
203 "Overload parameters must have the same names\n",
204 new_params[i].ident(),
205 old_params[i].ident());
206 }
207 }
208
tryCalculateDefaultParam(const Argument & arg,const py::object & def_value)209 std::optional<IValue> tryCalculateDefaultParam(
210 const Argument& arg,
211 const py::object& def_value) {
212 auto n = arg.N();
213 auto list_type = arg.type()->cast<ListType>();
214 try {
215 if (n && *n > 0 && list_type) {
216 // BroadcastingList, allow default values T for arg types List[T]
217 return toIValue(def_value, list_type->getElementType());
218 } else {
219 return toIValue(def_value, arg.type());
220 }
221 } catch (...) {
222 return std::nullopt;
223 }
224 }
225
226 // An overloaded function may have a default that does not subtype all overloads
227 // @overload
228 // def foo(x: str)
229 // def foo(x=1)
calcOverloadedFunctionDefaults(const FunctionSchema & schema,const FunctionDefaults & defaults)230 FunctionDefaults calcOverloadedFunctionDefaults(
231 const FunctionSchema& schema,
232 const FunctionDefaults& defaults) {
233 FunctionDefaults updated_defaults;
234 for (const auto& arg : schema.arguments()) {
235 const std::string& arg_name = arg.name();
236 auto value = defaults.find(arg_name);
237 if (value == defaults.end()) {
238 continue;
239 }
240 auto maybe_ivalue = tryCalculateDefaultParam(arg, value->second);
241 if (maybe_ivalue) {
242 updated_defaults[arg_name] = value->second;
243 }
244 }
245 return updated_defaults;
246 }
247
248 } // namespace
249
checkMutableFunctionDefault(const py::object & def_arg)250 bool checkMutableFunctionDefault(const py::object& def_arg) {
251 if (py::isinstance<py::list>(def_arg) || py::isinstance<py::dict>(def_arg)) {
252 return true;
253 }
254 if (py::isinstance<py::tuple>(def_arg)) {
255 auto pytuple = def_arg.cast<py::tuple>();
256 for (py::handle t : pytuple) {
257 py::object obj = py::reinterpret_borrow<py::object>(t);
258 if (checkMutableFunctionDefault(obj)) {
259 return true;
260 }
261 }
262 }
263 return false;
264 }
265
checkMutableFunctionDefault(const SourceRange & range,const Argument & arg,const py::object & def_arg)266 void checkMutableFunctionDefault(
267 const SourceRange& range,
268 const Argument& arg,
269 const py::object& def_arg) {
270 if (checkMutableFunctionDefault(def_arg) || arg.type()->cast<ClassType>()) {
271 throw(
272 ErrorReport(range)
273 << "Mutable default parameters are not supported because Python binds them to the function"
274 << " and they persist across function calls.\n As a workaround, make the default None and instantiate"
275 << " the default parameter within the body of the function. Found "
276 << def_arg.get_type() << " on parameter " << arg.name());
277 }
278 }
279
// Returns a copy of `schema` with two changes:
//  - it is renamed to `new_name` when one is given;
//  - each argument with an entry in `default_args` gets that Python value,
//    converted to the argument's type, as its default.
//
// Throws ErrorReport (located at `range`) when a default is mutable or
// class-typed (see checkMutableFunctionDefault), or when it cannot be
// converted to the argument's type.
FunctionSchema getSchemaWithNameAndDefaults(
    const SourceRange& range,
    const FunctionSchema& schema,
    const std::optional<std::string>& new_name,
    const FunctionDefaults& default_args) {
  const auto& args = schema.arguments();
  std::vector<Argument> new_args;
  new_args.reserve(args.size());
  for (const auto& arg : args) {
    auto it = default_args.find(arg.name());
    if (it == default_args.end()) {
      new_args.push_back(arg);
      continue;
    }
    checkMutableFunctionDefault(range, arg, it->second);
    std::optional<IValue> value = tryCalculateDefaultParam(arg, it->second);
    if (!value) {
      ErrorReport error(range);
      error << "Expected a default value of type " << arg.type()->repr_str()
            << " on parameter \"" << arg.name() << "\".";
      if (arg.is_inferred_type()) {
        // Leading space separates this sentence from the previous one.
        error << " Because \"" << arg.name()
              << "\" was not annotated with an explicit type "
              << "it is assumed to be type 'Tensor'.";
      }
      throw ErrorReport(error);
    }
    new_args.emplace_back(
        arg.name(), arg.type(), arg.N(), *value, arg.kwarg_only());
  }
  return FunctionSchema(
      new_name.value_or(schema.name()),
      schema.overload_name(),
      new_args,
      schema.returns(),
      schema.is_vararg(),
      schema.is_varret());
}
316
// Builds the declaration used to compile an overload. It contains the
// overload's own parameters, which must match the implementation's leading
// parameters by name, followed by any extra trailing parameters of the
// implementation. Each extra parameter must have an entry in `defaults` and
// an explicit type annotation. The overload's range and return type are kept.
static Decl mergeDefaultsAndExtraParametersToOverloadDecl(
    const Decl& overload_decl,
    const Decl& impl_decl,
    const FunctionDefaults& defaults) {
  std::vector<Param> adjusted_params;
  const auto& overload_params = overload_decl.params();
  const auto& impl_params = impl_decl.params();

  // following PEP specification that the following should work:
  // @overload
  // def mouse_event(x1: int, y1: int) -> ClickEvent: ...
  // ...
  // def mouse_event(x1: int, y1: int, x2: Optional[int] = None, y2:
  // Optional[int] = None)
  TORCH_CHECK(
      overload_params.size() <= impl_params.size(),
      "Overload should not have more parameters than implementation function",
      overload_decl.range(),
      impl_decl.range());

  adjusted_params.reserve(impl_params.size());
  for (const auto i : c10::irange(overload_params.size())) {
    auto overload_name = overload_params[i].ident().name();
    auto impl_name = impl_params[i].ident().name();
    if (overload_name != impl_name) {
      throw(
          ErrorReport(overload_decl.range())
          << "Overload parameters must have the same names. "
          << "Found " << overload_name << " and " << impl_name
          << " on argument " << i);
    }
    adjusted_params.push_back(overload_params[i]);
  }
  for (size_t i = overload_params.size(); i < impl_params.size(); ++i) {
    const auto param_name = impl_params[i].ident().name();
    if (!defaults.count(param_name)) {
      // Note the trailing space after "argument": the name follows directly.
      throw(
          ErrorReport(impl_decl.range())
          << "Expected to find default parameter on argument "
          << param_name
          << " because it is not defined on the overloaded declaration");
    }
    if (!impl_params[i].type().present()) {
      throw(
          ErrorReport(impl_decl.range())
          << "Parameters not specified on the overloaded declaration must have a type annotation in the implementation function."
          << " Did not find type for param " << param_name);
    }
    adjusted_params.push_back(impl_params[i]);
  }
  return Decl::create(
      overload_decl.range(),
      List<Param>::create(overload_decl.range(), adjusted_params),
      overload_decl.return_type());
}
370
script_compile_overloaded_function(const c10::QualifiedName & name,const Decl & overload_decl,const Def & implementation_def,const ResolutionCallback & rcb,const FunctionDefaults & implementation_defaults,const py::object & signature)371 static StrongFunctionPtr script_compile_overloaded_function(
372 const c10::QualifiedName& name,
373 const Decl& overload_decl,
374 const Def& implementation_def,
375 const ResolutionCallback& rcb,
376 const FunctionDefaults& implementation_defaults,
377 const py::object& signature) {
378 if (signature.is_none()) {
379 throw(
380 ErrorReport(overload_decl.range())
381 << "Must explicitly add type annotations to overloaded functions");
382 }
383
384 auto adjusted_decl = mergeDefaultsAndExtraParametersToOverloadDecl(
385 overload_decl, implementation_def.decl(), implementation_defaults);
386 auto new_def = implementation_def.withDecl(adjusted_decl);
387 auto cu = get_python_cu();
388 auto defined_functions = cu->define(
389 QualifiedName(name.prefix()),
390 /*properties=*/{},
391 /*propResolvers=*/{},
392 {new_def},
393 {pythonResolver(rcb)},
394 nullptr,
395 true);
396 TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
397 auto& defined = defined_functions[0];
398 FunctionDefaults updated_defaults = calcOverloadedFunctionDefaults(
399 defined->getSchema(), implementation_defaults);
400 defined->setSchema(getSchemaWithNameAndDefaults(
401 new_def.range(),
402 defined->getSchema(),
403 new_def.name().name(),
404 updated_defaults));
405 StrongFunctionPtr ret(std::move(cu), defined);
406 didFinishEmitFunction(ret);
407 return ret;
408 }
409
script_compile_function(const c10::QualifiedName & name,const Def & def,const FunctionDefaults & defaults,const ResolutionCallback & rcb)410 static StrongFunctionPtr script_compile_function(
411 const c10::QualifiedName& name,
412 const Def& def,
413 const FunctionDefaults& defaults,
414 const ResolutionCallback& rcb) {
415 auto cu = get_python_cu();
416 auto defined_functions = cu->define(
417 QualifiedName(name.prefix()),
418 /*properties=*/{},
419 /*propResolvers=*/{},
420 {def},
421 {pythonResolver(rcb)},
422 nullptr,
423 true);
424 TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
425 auto& defined = defined_functions[0];
426 defined->setSchema(getSchemaWithNameAndDefaults(
427 def.range(), defined->getSchema(), def.name().name(), defaults));
428 StrongFunctionPtr ret(std::move(cu), defined);
429 didFinishEmitFunction(ret);
430 return ret;
431 }
432
433 struct VISIBILITY_HIDDEN ModuleSelf : public Self {
ModuleSelftorch::jit::ModuleSelf434 ModuleSelf(std::shared_ptr<ConcreteModuleType> concreteType)
435 : Self(), concreteType_(std::move(concreteType)) {}
436
makeSugaredtorch::jit::ModuleSelf437 std::shared_ptr<SugaredValue> makeSugared(Value* v) const override {
438 v->setType(getClassType());
439 return std::make_shared<ModuleValue>(v, concreteType_);
440 }
441
getClassTypetorch::jit::ModuleSelf442 ClassTypePtr getClassType() const override {
443 return concreteType_->getJitType()->expect<ClassType>();
444 }
445
446 private:
447 std::shared_ptr<ConcreteModuleType> concreteType_;
448 };
449
_propagate_shapes(Graph & graph,std::vector<at::Tensor> inputs,bool with_grad=false)450 static std::shared_ptr<Graph> _propagate_shapes(
451 Graph& graph,
452 std::vector<at::Tensor> inputs,
453 bool with_grad = false) {
454 Stack stack(inputs.begin(), inputs.end());
455 auto retval = graph.copy();
456 setInputTensorTypes(*retval, stack, /*complete=*/false);
457 PropagateInputShapes(retval);
458 return retval;
459 }
460
_propagate_and_assign_input_shapes(Graph & graph,const std::vector<at::Tensor> & inputs,const std::vector<int> & param_count_list,bool with_grad=false,bool propagate=true)461 static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
462 Graph& graph,
463 const std::vector<at::Tensor>& inputs,
464 const std::vector<int>& param_count_list,
465 bool with_grad = false,
466 bool propagate = true) {
467 auto retval = graph.copy();
468 setInputTensorTypes(
469 *retval, fmap<IValue>(inputs), /*complete=*/true, param_count_list);
470 if (propagate) {
471 PropagateInputShapes(retval);
472 }
473 return retval;
474 }
475
addFunctionToModule(Module & module,const StrongFunctionPtr & func)476 void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
477 // Make a graph with a fake self argument
478 auto graph = toGraphFunction(*func.function_).graph()->copy();
479 auto v = graph->insertInput(0, "self");
480 v->setType(module._ivalue()->type());
481 const auto name = QualifiedName(*module.type()->name(), "forward");
482 auto method =
483 module._ivalue()->compilation_unit()->create_function(name, graph);
484 module.type()->addMethod(method);
485 }
486
487 // this is used in our test suite to check that we correctly preserved type tags
ivalue_tags_match(const Module & lhs,const Module & rhs)488 bool ivalue_tags_match(const Module& lhs, const Module& rhs) {
489 struct Work {
490 IValue a;
491 IValue b;
492 };
493 std::unordered_set<const void*> visited;
494 std::vector<Work> work = {{lhs._ivalue(), rhs._ivalue()}};
495 while (!work.empty()) {
496 Work item = work.back();
497 work.pop_back();
498 if (item.a.isPtrType()) {
499 // uncomment to debug type matching errors
500 // std::cout << "MATCHING " << /*item.a <<*/ "(" << *item.a.type() << ") "
501 // << item.a.internalToPointer() << " " << /*item.b <<*/ " ("
502 // << *item.b.type() << ") " << item.b.internalToPointer() <<
503 // "\n";
504
505 if (visited.count(item.a.internalToPointer())) {
506 continue;
507 }
508 visited.emplace(item.a.internalToPointer());
509 }
510 if (!unshapedType(item.b.type())
511 ->isSubtypeOf(unshapedType(item.b.type()))) {
512 // Since named types are saved and loaded in the test suite, we cannot
513 // expect them to be equal. We should still check their slots however.
514 if (!item.a.type()->cast<c10::NamedType>()) {
515 return false;
516 }
517 }
518 // check tags for objects that contain subobjects
519 if (item.a.isObject()) {
520 auto ao = item.a.toObject();
521 auto bo = item.b.toObject();
522 for (size_t i = 0; i < ao->slots().size(); ++i) {
523 work.emplace_back(Work{ao->slots().at(i), bo->slots().at(i)});
524 }
525 } else if (item.a.isTuple()) {
526 auto at = item.a.toTuple();
527 auto bt = item.b.toTuple();
528 for (size_t i = 0; i < at->elements().size(); ++i) {
529 work.emplace_back(Work{at->elements().at(i), bt->elements().at(i)});
530 }
531 } else if (item.a.isList()) {
532 auto al = item.a.toList();
533 auto bl = item.b.toList();
534 for (const auto i : c10::irange(al.size())) {
535 work.emplace_back(Work{al.get(i), bl.get(i)});
536 }
537 } else if (item.a.isGenericDict()) {
538 auto ad = item.a.toGenericDict();
539 auto bd = item.b.toGenericDict();
540 for (auto& item : ad) {
541 // Dictionaory keys cannot contain List/Dicts that require tags
542 // so we do not have to check them.
543 // Furthermore without ordered dicts it is expensive to find the
544 // equivalent key
545 work.emplace_back(Work{item.value(), bd.at(item.key())});
546 }
547 } else if (item.a.isFuture()) {
548 auto af = item.a.toFuture();
549 auto bf = item.b.toFuture();
550 af->wait();
551 bf->wait();
552 work.emplace_back(Work{af->value(), bf->value()});
553 }
554 }
555
556 return true;
557 }
558
559 // helper used to implement ._parameters, ._buffers, ._modules dicts
560 // inside of script nn.Module
561 template <typename Policy>
562 struct slot_dict_impl {
slot_dict_impltorch::jit::slot_dict_impl563 slot_dict_impl(ModulePtr module) : module_(std::move(module)) {}
containstorch::jit::slot_dict_impl564 bool contains(const std::string& name) const {
565 if (auto slot = module_->type()->findAttributeSlot(name)) {
566 if (Policy::valid(module_->type(), *slot, module_->getSlot(*slot))) {
567 return true;
568 }
569 }
570 return false;
571 }
572
itemstorch::jit::slot_dict_impl573 std::vector<std::pair<std::string, py::object>> items() const {
574 std::vector<std::pair<std::string, py::object>> result;
575 for (size_t i = 0, N = module_->type()->numAttributes(); i < N; ++i) {
576 if (Policy::valid(module_->type(), i, module_->getSlot(i))) {
577 result.emplace_back(
578 module_->type()->getAttributeName(i),
579 toPyObject(module_->getSlot(i)));
580 }
581 }
582 return result;
583 }
584
setattrtorch::jit::slot_dict_impl585 void setattr(const std::string& name, py::object value) {
586 const TypePtr& type = module_->type()->getAttribute(name);
587 Module(module_).setattr(name, toIValue(std::move(value), type));
588 }
589
getattrtorch::jit::slot_dict_impl590 py::object getattr(const std::string& name) {
591 return toPyObject(Module(module_).attr(name));
592 }
593
bindtorch::jit::slot_dict_impl594 static void bind(const py::module& m, const char* name) {
595 py::class_<slot_dict_impl<Policy>>(m, name)
596 .def(py::init(
597 [](Module& m) { return slot_dict_impl<Policy>(m._ivalue()); }))
598 .def("contains", &slot_dict_impl<Policy>::contains)
599 .def("items", &slot_dict_impl<Policy>::items)
600 .def("setattr", &slot_dict_impl<Policy>::setattr)
601 .def("getattr", &slot_dict_impl<Policy>::getattr);
602 }
603
604 private:
605 ModulePtr module_;
606 };
607
608 template <typename T>
debugMakeList(const T & list)609 py::list debugMakeList(const T& list) {
610 py::list result;
611 for (const auto& elem : list) {
612 result.append(py::cast(elem));
613 }
614 return result;
615 }
616 template <typename T>
debugMakeNamedList(const T & list)617 py::list debugMakeNamedList(const T& list) {
618 py::list result;
619 for (auto elem : list) {
620 result.append(py::cast(std::make_pair(elem.name, elem.value)));
621 }
622 return result;
623 }
624 template <typename T>
debugMakeSet(const T & list)625 py::set debugMakeSet(const T& list) {
626 py::set result;
627 for (const auto& elem : list) {
628 result.add(py::cast(elem));
629 }
630 return result;
631 }
632
_jit_debug_module_iterators(Module & module)633 static py::dict _jit_debug_module_iterators(Module& module) {
634 py::dict result;
635 result["children"] = debugMakeList(module.children());
636 result["named_children"] = debugMakeNamedList(module.named_children());
637 result["modules"] = debugMakeList(module.modules());
638 result["named_modules"] = debugMakeNamedList(module.named_modules());
639
640 result["parameters"] = debugMakeList(module.parameters(false));
641 result["named_parameters"] =
642 debugMakeNamedList(module.named_parameters(false));
643 result["parameters_r"] = debugMakeList(module.parameters(true));
644 result["named_parameters_r"] =
645 debugMakeNamedList(module.named_parameters(true));
646
647 result["buffers"] = debugMakeList(module.buffers(false));
648 result["named_buffers"] = debugMakeNamedList(module.named_buffers(false));
649 result["buffers_r"] = debugMakeList(module.buffers(true));
650 result["named_buffers_r"] = debugMakeNamedList(module.named_buffers(true));
651
652 result["named_attributes"] =
653 debugMakeNamedList(module.named_attributes(false));
654 result["named_attributes_r"] =
655 debugMakeNamedList(module.named_attributes(true));
656 return result;
657 }
658
// Python "magic" (dunder) method names. The element type and count are
// deduced by class template argument deduction (std::array<const char*, 48>
// today). If a name were removed from a hand-sized array, the trailing slot
// would silently become nullptr; deducing the size prevents that.
static constexpr std::array magic_method_names = {
    "__lt__",      "__le__",      "__eq__",        "__ne__",
    "__ge__",      "__gt__",      "__not__",       "__abs__",
    "__add__",     "__and__",     "__floordiv__",  "__index__",
    "__inv__",     "__invert__",  "__lshift__",    "__mod__",
    "__mul__",     "__matmul__",  "__neg__",       "__or__",
    "__pos__",     "__pow__",     "__rshift__",    "__sub__",
    "__truediv__", "__xor__",     "__concat__",    "__contains__",
    "__delitem__", "__getitem__", "__setitem__",   "__iadd__",
    "__iand__",    "__iconcat__", "__ifloordiv__", "__ilshift__",
    "__imod__",    "__imul__",    "__imatmul__",   "__ior__",
    "__ipow__",    "__irshift__", "__isub__",      "__itruediv__",
    "__ixor__",    "__str__",     "__len__",       "__repr__",
};
673
// Wrapper for the identity-keyed IValue memo used by pyIValueDeepcopy. An
// instance is stored in the Python `memo` dict passed to pyIValueDeepcopy
// (under "__torch_script_memo_table"), so later calls that share that dict
// reuse the same memo table.
struct DeepCopyMemoTable {
  // Shared so copies of this wrapper (e.g. via py::cast) all refer to the
  // same underlying table.
  std::shared_ptr<IValue::HashIdentityIValueMap> map;
};
677
// Deep-copies `ivalue` using an identity-keyed memo table stashed in the
// Python `memo` dict under "__torch_script_memo_table". The table is created
// on first use and reused by subsequent calls that share `memo`.
IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) {
  const char* const kMemoKey = "__torch_script_memo_table";
  if (!memo.contains(py::str(kMemoKey))) {
    memo[kMemoKey] =
        DeepCopyMemoTable{std::make_shared<IValue::HashIdentityIValueMap>()};
  }
  DeepCopyMemoTable table = py::cast<DeepCopyMemoTable>(memo[kMemoKey]);
  return ivalue.deepcopy(*table.map);
}
687
extra_files_from_python(const py::dict & pydict)688 ExtraFilesMap extra_files_from_python(const py::dict& pydict) {
689 ExtraFilesMap r;
690 for (const auto& it : pydict) {
691 r[py::cast<std::string>(it.first)] = "";
692 }
693 return r;
694 }
695
extra_files_to_python(const ExtraFilesMap & m,const py::dict & pydict)696 void extra_files_to_python(const ExtraFilesMap& m, const py::dict& pydict) {
697 // py::dict is pointer-like type so it gets modified despite const&
698 for (const auto& it : m) {
699 pydict[py::str(it.first)] = py::bytes(it.second);
700 }
701 }
702
pyCompilationUnitDefine(CompilationUnit & cu,const std::string & src,const ResolutionCallback * rcb,const uint32_t _frames_up)703 void pyCompilationUnitDefine(
704 CompilationUnit& cu,
705 const std::string& src,
706 const ResolutionCallback* rcb,
707 const uint32_t _frames_up) {
708 if (rcb && *rcb) {
709 cu.define(std::nullopt, src, pythonResolver(*rcb), nullptr);
710 } else {
711 py::object py_default_rcb =
712 py::module::import("torch._jit_internal")
713 .attr("createResolutionCallbackFromFrame")(_frames_up);
714 auto default_rcb = py_default_rcb.cast<ResolutionCallback>();
715 cu.define(std::nullopt, src, pythonResolver(default_rcb), nullptr);
716 }
717 }
718
719 // This function will copy bytes into a shared_ptr of chars aligned
720 // at kFlatbufferDataAlignmentBytes boundary (currently 16).
721 // This is required because tensors need to be aligned at 16 bytes boundary.
copyStr(const std::string & bytes)722 static std::shared_ptr<char> copyStr(const std::string& bytes) {
723 size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
724 kFlatbufferDataAlignmentBytes;
725 #ifdef _WIN32
726 std::shared_ptr<char> bytes_copy(
727 static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
728 _aligned_free);
729 #elif defined(__APPLE__)
730 void* p;
731 ::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
732 TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
733 std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
734 #else
735 std::shared_ptr<char> bytes_copy(
736 static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
737 free);
738 #endif
739 memcpy(bytes_copy.get(), bytes.data(), bytes.size());
740 return bytes_copy;
741 }
742
initJitScriptBindings(PyObject * module)743 void initJitScriptBindings(PyObject* module) {
744 auto m = py::handle(module).cast<py::module>();
745
746 // NOLINTNEXTLINE(bugprone-unused-raii)
747 py::class_<c10::Capsule>(m, "Capsule");
748
749 auto object_class =
750 py::class_<Object>(m, "ScriptObject")
751 .def("_type", [](Object& o) { return o.type(); })
752 .def(
753 "_get_method",
754 [](Object& self, const std::string& name) -> Method {
755 return self.get_method(name);
756 },
757 py::keep_alive<0, 1>())
758 .def(
759 "setattr",
760 [](Object& self, const std::string& name, py::object value) {
761 if (self.type()->hasConstant(name)) {
762 TORCH_CHECK(
763 false,
764 "Can't set constant '",
765 name,
766 "' which has value:",
767 self.type()->getConstant(name));
768 }
769 TypePtr type = self.type()->getAttribute(name);
770 try {
771 auto ivalue = toIValue(std::move(value), type);
772 self.setattr(name, ivalue);
773 } catch (std::exception& e) {
774 throw py::cast_error(c10::str(
775 "Could not cast attribute '",
776 name,
777 "' to type ",
778 type->repr_str(),
779 ": ",
780 e.what()));
781 }
782 })
783 .def(
784 "getattr",
785 [](Object& self, const std::string& name) {
786 try {
787 return toPyObject(self.attr(name));
788 } catch (const ObjectAttributeError& err) {
789 throw AttributeError("%s", err.what());
790 }
791 })
792 .def(
793 "__getattr__",
794 [](Object& self, const std::string& name) -> py::object {
795 try {
796 if (name == "__qualname__") {
797 return py::cast(self.type()->name()->name());
798 }
799 if (auto method = self.find_method(name)) {
800 return py::cast(*method);
801 }
802 if (self.has_property(name)) {
803 auto prop = self.get_property(name);
804 // wrap the Method into callable PyObject
805 auto getter_func = py::cast(prop.getter_func);
806 return getter_func();
807 }
808 return toPyObject(self.attr(name));
809 } catch (const ObjectAttributeError& err) {
810 throw AttributeError("%s", err.what());
811 }
812 })
813 .def(
814 "__setattr__",
815 [](Object& self, const std::string& name, py::object value) {
816 try {
817 if (self.has_property(name)) {
818 auto prop = self.get_property(name);
819 if (!prop.setter_func.has_value()) {
820 TORCH_CHECK(false, "can't set attribute");
821 }
822 // wrap the Method into callable PyObject
823 auto setter_func = py::cast(prop.setter_func);
824 setter_func(value);
825 return;
826 }
827
828 if (self.type()->hasConstant(name)) {
829 TORCH_CHECK(
830 false,
831 "Can't set constant '",
832 name,
833 "' which has value:",
834 self.type()->getConstant(name));
835 }
836 TypePtr type = self.type()->getAttribute(name);
837 auto ivalue = toIValue(std::move(value), type);
838 self.setattr(name, ivalue);
839 } catch (const ObjectAttributeError& err) {
840 throw AttributeError("%s", err.what());
841 }
842 })
843 .def(
844 "hasattr",
845 [](Object& self, const std::string& name) {
846 return self.hasattr(name);
847 })
848 .def(
849 "_has_method",
850 [](Object& self, const std::string& name) {
851 return bool(self.find_method(name));
852 })
853 .def(
854 "_method_names",
855 [](Object& self) {
856 return fmap(self.get_methods(), [](const Method& method) {
857 return method.name();
858 });
859 })
860 .def(
861 "_properties", [](Object& self) { return self.get_properties(); })
862 .def("__copy__", &Object::copy)
863 .def(
864 "__hash__",
865 [](const Object& self) {
866 // Similar to Tensor's `__hash__`, which is `id()`.
867 return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
868 })
869 .def(py::pickle(
870 [](const Object& self)
871 -> std::tuple<py::object, std::string> { // __getstate__
872 if (auto getstate_method = self.find_method("__getstate__")) {
873 auto object_state = toPyObject((*getstate_method)(Stack{}));
874 TORCH_INTERNAL_ASSERT(self.type()->name());
875 return std::make_tuple(
876 object_state, self.type()->name()->qualifiedName());
877 }
878 std::stringstream err;
879 err << "Tried to serialize object ";
880 if (auto qualname = self.type()->name()) {
881 err << qualname->qualifiedName() << " ";
882 }
883 err << "which does not have a __getstate__ method defined!";
884 throw std::runtime_error(err.str());
885 },
886 [](const std::tuple<py::object, std::string>& state_tup)
887 -> Object {
888 auto [state, qualname] = state_tup;
889 auto class_type = getCustomClass(qualname);
890 TORCH_CHECK(
891 class_type,
892 "Tried to deserialize class ",
893 qualname,
894 " which is not known to the runtime. "
895 "If this is a custom C++ class, make "
896 "sure the appropriate code is linked.");
897
898 auto self = Object(c10::ivalue::Object::create(
899 c10::StrongTypePtr(
900 std::shared_ptr<torch::jit::CompilationUnit>(),
901 class_type),
902 1));
903 if (auto setstate_method = self.find_method("__setstate__")) {
904 auto setstate_schema =
905 setstate_method->function().getSchema();
906 TORCH_INTERNAL_ASSERT(
907 setstate_schema.arguments().size() == 2,
908 "__setstate__ method for class ",
909 class_type->repr_str(),
910 " must have exactly 2 arguments!");
911 auto state_type = setstate_schema.arguments().at(1).type();
912 (*setstate_method)(Stack{toIValue(state, state_type)});
913 return self;
914 }
915 std::stringstream err;
916 err << "Tried to deserialize object ";
917 if (auto qualname = class_type->name()) {
918 err << qualname->qualifiedName() << " ";
919 }
920 err << "which does not have a __setstate__ method defined!";
921 throw std::runtime_error(err.str());
922 }));
923
924 py::class_<Object::Property>(m, "ScriptObjectProperty")
925 .def_property_readonly(
926 "name", [](const Object::Property& self) { return self.name; })
927 .def_property_readonly(
928 "getter",
929 [](const Object::Property& self) { return self.getter_func; })
930 .def_property_readonly("setter", [](const Object::Property& self) {
931 return self.setter_func;
932 });
933
934 // Special case __str__ and __repr__ to make sure we can print Objects/Modules
935 // regardless of if the user defined __str__/__repr__
936 using MagicMethodImplType = std::function<py::object(
937 const Object& self, py::args args, py::kwargs kwargs)>;
938
939 std::unordered_map<std::string, MagicMethodImplType> special_magic_methods;
940 special_magic_methods.emplace(
941 "__str__",
942 [](const Object& self,
943 const py::args& args,
944 const py::kwargs& kwargs) -> py::object {
945 auto method = self.find_method("__str__");
946 if (!method) {
947 return py::str("ScriptObject <" + self.type()->str() + ">");
948 }
949 return invokeScriptMethodFromPython(*method, args, kwargs);
950 });
951
952 special_magic_methods.emplace(
953 "__repr__",
954 [](const Object& self,
955 const py::args& args,
956 const py::kwargs& kwargs) -> py::object {
957 auto method = self.find_method("__repr__");
958 if (!method) {
959 std::stringstream ss;
960 ss << std::hex << static_cast<const void*>(&self);
961 return py::str("<torch.ScriptObject object at " + ss.str() + ">");
962 }
963 return invokeScriptMethodFromPython(*method, args, kwargs);
964 });
965
966 for (const char* mm_name : magic_method_names) {
967 if (special_magic_methods.count(mm_name)) {
968 object_class.def(mm_name, special_magic_methods[mm_name]);
969 } else {
970 object_class.def(
971 mm_name,
972 [mm_name](
973 const Object& self,
974 const py::args& args,
975 const py::kwargs& kwargs) {
976 auto method = self.find_method(mm_name);
977 if (!method) {
978 std::string msg = fmt::format(
979 "'{}' is not implemented for {}",
980 mm_name,
981 self.type()->str());
982 throw c10::NotImplementedError(msg);
983 }
984 return invokeScriptMethodFromPython(*method, args, kwargs);
985 });
986 }
987 }
988
  // Registers DeepCopyMemoTable as an opaque Python type. No methods are
  // bound; the temporary py::class_ only performs the registration, which is
  // why the unused-RAII lint is suppressed.
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<DeepCopyMemoTable>(m, "DeepCopyMemoTable");
991
992 py::class_<UpgraderEntry>(m, "_UpgraderEntry")
993 .def(py::init<int, std::string, std::string>())
994 .def_property_readonly(
995 "bumped_at_version",
996 [](const UpgraderEntry& self) { return self.bumped_at_version; })
997 .def_property_readonly(
998 "upgrader_name",
999 [](const UpgraderEntry& self) { return self.upgrader_name; })
1000 .def_property_readonly("old_schema", [](const UpgraderEntry& self) {
1001 return self.old_schema;
1002 });
1003
1004 py::class_<UpgraderRange>(m, "_UpgraderRange")
1005 .def(py::init<int, int>())
1006 .def_property_readonly(
1007 "min_version",
1008 [](const UpgraderRange& self) { return self.min_version; })
1009 .def_property_readonly("max_version", [](const UpgraderRange& self) {
1010 return self.max_version;
1011 });
1012
1013 object_class.def(
1014 "__deepcopy__", [](const Object& self, const py::dict& memo) {
1015 return Object(
1016 pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject());
1017 });
1018
1019 // Used by torch.package to save ScriptModule objects in unified format.
1020 py::class_<ScriptModuleSerializer>(m, "ScriptModuleSerializer")
1021 .def(py::init<caffe2::serialize::PyTorchStreamWriter&>())
1022 .def("serialize", &ScriptModuleSerializer::serialize_unified_format)
1023 .def(
1024 "write_files",
1025 &ScriptModuleSerializer::writeFiles,
1026 py::arg("code_dir") = ".data/ts_code/code/")
1027 .def(
1028 "storage_context",
1029 &ScriptModuleSerializer::storage_context,
1030 pybind11::return_value_policy::reference_internal);
1031
1032 // Used by torch.package to coordinate sharing of storages between eager
1033 // and ScriptModules.
1034 py::class_<
1035 SerializationStorageContext,
1036 std::shared_ptr<SerializationStorageContext>>(
1037 m, "SerializationStorageContext")
1038 .def("has_storage", &SerializationStorageContext::hasStorage)
1039 .def("get_or_add_storage", &SerializationStorageContext::getOrAddStorage);
1040
  // torch.jit.ScriptModule is a subclass of this C++ object.
  // Methods here are prefixed with _ since they should not be
  // public.
  //
  // Binds torch::jit::Module (registered with `Object` as its pybind base) for
  // Python. It covers saving and serialization, trace-based method creation,
  // hook and code introspection, and identity-based hashing and equality.
  py::class_<Module, Object>(m, "ScriptModule")
      .def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>())
      // Saves the module to `filename`. `_extra_files` is forwarded to
      // Module::save.
      .def(
          "save",
          [](Module& m,
             const std::string& filename,
             const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
            m.save(filename, _extra_files);
          },
          py::arg("filename"),
          py::arg("_extra_files") = ExtraFilesMap())
      // Same as `save`, but serializes into an in-memory stream and returns
      // the bytes.
      .def(
          "save_to_buffer",
          [](Module& m, const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
            std::ostringstream buf;
            m.save(buf, _extra_files);
            return py::bytes(buf.str());
          },
          py::arg("_extra_files") = ExtraFilesMap())
      // Mobile-format save to a file. All flags are forwarded unchanged to
      // Module::_save_for_mobile.
      .def(
          "_save_for_mobile",
          [](Module& m,
             const std::string& filename,
             const ExtraFilesMap& _extra_files = ExtraFilesMap(),
             bool _save_mobile_debug_info = false,
             bool _use_flatbuffer = false) {
            m._save_for_mobile(
                filename,
                _extra_files,
                _save_mobile_debug_info,
                _use_flatbuffer);
          },
          py::arg("filename"),
          py::arg("_extra_files") = ExtraFilesMap(),
          py::arg("_save_mobile_debug_info") = false,
          py::arg("_use_flatbuffer") = false)
      // Mobile-format save into memory. Returns the serialized bytes.
      .def(
          "_save_to_buffer_for_mobile",
          [](Module& m,
             const ExtraFilesMap& _extra_files = ExtraFilesMap(),
             bool _save_mobile_debug_info = false,
             bool _use_flatbuffer = false) {
            std::ostringstream buf;
            m._save_for_mobile(
                buf, _extra_files, _save_mobile_debug_info, _use_flatbuffer);
            return py::bytes(buf.str());
          },
          py::arg("_extra_files") = ExtraFilesMap(),
          py::arg("_save_mobile_debug_info") = false,
          py::arg("_use_flatbuffer") = false)
      .def("_set_optimized", &Module::set_optimized)
      // Debug printing. `dump` and `dump_to_str` take the same three flags.
      .def(
          "dump",
          &Module::dump,
          py::arg("code") = true,
          py::arg("attrs") = true,
          py::arg("params") = true)
      .def(
          "dump_to_str",
          &Module::dump_to_str,
          py::arg("code") = true,
          py::arg("attrs") = true,
          py::arg("params") = true)
      // Builds a new object with the same type and compilation unit and
      // copies every slot into it. Slot values are copied as IValues; they
      // are not deep-copied.
      .def(
          "_replicate_for_data_parallel",
          [](Module& module) {
            const ModulePtr& obj = module._ivalue();
            auto copy = c10::ivalue::Object::create(
                c10::StrongTypePtr(obj->compilation_unit(), obj->type()),
                obj->slots().size());
            for (size_t i = 0; i < obj->slots().size(); ++i) {
              copy->setSlot(i, obj->getSlot(i));
            }
            return Module(std::move(copy));
          })
      // Executor debug state of `forward`. Throws std::runtime_error when
      // the module has no `forward` method.
      .def(
          "get_debug_state",
          [](Module& self) {
            if (auto m = self.find_method("forward")) {
              return m->get_executor().getDebugState();
            }
            throw std::runtime_error(
                "Attempted to call get_debug_state on a Module without a compiled forward()");
          })
      // Compiles `script` into the module's compilation unit under the
      // module type's name. Free names are resolved through `rcb`, and
      // `concreteType` describes `self`.
      .def(
          "_define",
          [](Module& m,
             std::shared_ptr<ConcreteModuleType> concreteType,
             const std::string& script,
             const ResolutionCallback& rcb) {
            const auto self = ModuleSelf(std::move(concreteType));
            m._ivalue()->compilation_unit()->define(
                m.type()->name(), script, pythonResolver(rcb), &self);
            didFinishEmitModule(m);
          })
      // Adds attribute `name` of `type`. The Python `value` is converted to
      // an IValue of that type.
      .def(
          "_register_attribute",
          [](Module& m,
             const std::string& name,
             const TypePtr& type,
             py::handle value) {
            m.register_attribute(name, type, toIValue(value, type));
          })
      // Traces `func` on `input_tuple` with this module passed as `self`,
      // then installs the traced graph as method `name` on the module type.
      // When `store_inputs` is true, the traceable inputs are recorded via
      // store_traced_inputs.
      // NOTE(review): `store_inputs` has no default even though it follows
      // a defaulted argument, so callers presumably must always pass it.
      .def(
          "_create_method_from_trace",
          [](Module& self,
             const std::string& name,
             const py::function& func,
             const py::tuple& input_tuple,
             const py::function& var_name_lookup_fn,
             bool strict,
             bool force_outplace,
             const std::vector<std::string>& argument_names,
             bool store_inputs) {
            // prereq: Module's buffers and parameters are unique
            // this was ensured in python before calling this function
            auto typed_inputs = toTraceableStack(input_tuple);

            std::shared_ptr<Graph> graph =
                std::get<0>(tracer::createGraphByTracing(
                    func,
                    typed_inputs,
                    var_name_lookup_fn,
                    strict,
                    force_outplace,
                    &self,
                    argument_names));
            const auto method_name = QualifiedName(*self.type()->name(), name);
            auto fn = self._ivalue()->compilation_unit()->create_function(
                method_name, graph);
            self.type()->addMethod(fn);
            if (store_inputs) {
              self.store_traced_inputs(name, typed_inputs);
            }
            didFinishEmitModule(self);
          },
          py::arg("name"),
          py::arg("func"),
          py::arg("input_tuple"),
          py::arg("var_name_lookup_fn"),
          py::arg("strict"),
          py::arg("force_outplace"),
          py::arg("argument_names") = std::vector<std::string>(),
          py::arg("store_inputs"))
      // Dict-input variant of `_create_method_from_trace`.
      // NOTE(review): this variant records traced inputs before addMethod,
      // while the tuple variant does it after. Confirm the order is
      // irrelevant.
      .def(
          "_create_method_from_trace_with_dict",
          [](Module& self,
             const std::string& name,
             const py::function& func,
             const py::dict& input_dict,
             const py::function& var_name_lookup_fn,
             bool strict,
             bool force_outplace,
             const std::vector<std::string>& argument_names,
             bool store_inputs) {
            // prereq: Module's buffers and parameters are unique
            // this was ensured in python before calling this function
            auto typed_inputs = toTraceableStack(input_dict);

            std::shared_ptr<Graph> graph =
                std::get<0>(tracer::createGraphByTracingWithDict(
                    func,
                    input_dict,
                    typed_inputs,
                    var_name_lookup_fn,
                    strict,
                    force_outplace,
                    &self,
                    argument_names));
            const auto method_name = QualifiedName(*self.type()->name(), name);
            auto fn = self._ivalue()->compilation_unit()->create_function(
                method_name, graph);
            if (store_inputs) {
              self.store_traced_inputs(name, typed_inputs);
            }
            self.type()->addMethod(fn);
            didFinishEmitModule(self);
          },
          py::arg("name"),
          py::arg("func"),
          py::arg("input_dict"),
          py::arg("var_name_lookup_fn"),
          py::arg("strict"),
          py::arg("force_outplace"),
          py::arg("argument_names") = std::vector<std::string>(),
          py::arg("store_inputs"))
      // Wraps each forward hook in a StrongFunctionPtr paired with the
      // module type's compilation unit.
      .def(
          "_get_forward_hooks",
          [](const Module& m) {
            std::vector<StrongFunctionPtr> funcs;
            for (auto& hook : m.type()->getForwardHooks()) {
              funcs.emplace_back(m.type()->compilation_unit(), hook);
            }
            return funcs;
          })
      // Same as `_get_forward_hooks`, for forward pre-hooks.
      .def(
          "_get_forward_pre_hooks",
          [](const Module& m) {
            std::vector<StrongFunctionPtr> funcs;
            for (auto& pre_hook : m.type()->getForwardPreHooks()) {
              funcs.emplace_back(m.type()->compilation_unit(), pre_hook);
            }
            return funcs;
          })
      // Inputs previously recorded via store_traced_inputs, as a ScriptDict.
      .def(
          "_retrieve_traced_inputs",
          [](const Module& m) {
            return ScriptDict(m.retrieve_traced_inputs());
          })
      // TorchScript source of the module's type, pretty-printed.
      .def_property_readonly(
          "code",
          [](Module& self) {
            std::vector<at::IValue> constants;
            PrintDepsTable deps;
            PythonPrint pp(constants, deps);
            pp.printNamedType(self.type());
            return pp.str();
          })
      // Like `code`, but also returns the constants collected by the
      // printer, keyed "c0", "c1", ... in collection order.
      .def_property_readonly(
          "code_with_constants",
          [](Module& self) {
            std::vector<at::IValue> constants;
            PrintDepsTable deps;
            PythonPrint pp(constants, deps);
            pp.printNamedType(self.type());
            std::map<std::string, at::IValue> consts;
            int i = 0;
            for (auto const& constant : constants) {
              consts["c" + std::to_string(i)] = constant;
              i += 1;
            }
            return std::make_tuple(pp.str(), std::move(consts));
          })
      .def("apply", &Module::apply)
      .def("__copy__", &Module::copy)
      // Identity hash: based on the address of the underlying ivalue object.
      .def(
          "__hash__",
          [](const Module& self) {
            // Similar to Tensor's `__hash__`, which is `id()`.
            return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
          })
      // Identity equality. Non-Module operands compare unequal.
      .def(
          "__eq__",
          [](const Module& self, const py::object& other) {
            // TODO: call UDF if it exists
            if (!py::isinstance<Module>(other)) {
              return false;
            }
            return self._ivalue().get() ==
                py::cast<Module>(other)._ivalue().get();
          })
      .def(
          "__deepcopy__",
          [](const Module& self, const py::dict& memo) {
            return Module(
                pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject());
          })
      .def("children", &Module::children)
      // Dereferences the type's optional name without checking it, which
      // assumes module types are always named.
      .def_property_readonly("qualified_name", [](const Module& self) {
        return self.type()->name()->qualifiedName();
      });
1305
1306 py::class_<mobile::Module>(m, "LiteScriptModule")
1307 .def(py::init<
1308 c10::intrusive_ptr<c10::ivalue::Object>,
1309 std::shared_ptr<mobile::CompilationUnit>>())
1310 .def(
1311 "find_method",
1312 [](mobile::Module& m, const std::string& method_name) {
1313 auto method = m.find_method(method_name);
1314 return method != std::nullopt;
1315 },
1316 py::arg("method_name"))
1317 .def(
1318 "run_method",
1319 [](mobile::Module& m,
1320 const std::string& method_name,
1321 const py::tuple& input_tuple) {
1322 Stack stack;
1323 for (auto& input : input_tuple) {
1324 stack.push_back(toTypeInferredIValue(input));
1325 }
1326 return m.get_method(method_name)(stack);
1327 },
1328 py::arg("method_name"),
1329 py::arg("input_tuple"))
1330 .def(
1331 "forward",
1332 [](mobile::Module& m, const py::tuple& input_tuple) {
1333 Stack stack;
1334 for (auto& input : input_tuple) {
1335 stack.push_back(toTypeInferredIValue(input));
1336 }
1337 return m.get_method("forward")(stack);
1338 },
1339 py::arg("input_tuple"));
1340
  // Registers the ParameterDict, BufferDict and ModuleDict Python types, one
  // per slot policy, through slot_dict_impl<Policy>::bind.
  slot_dict_impl<detail::ParameterPolicy>::bind(m, "ParameterDict");
  slot_dict_impl<detail::BufferPolicy>::bind(m, "BufferDict");
  slot_dict_impl<detail::ModulePolicy>::bind(m, "ModuleDict");
1344
1345 py::class_<ErrorReport, std::shared_ptr<ErrorReport>>(m, "ErrorReport")
1346 .def(py::init<SourceRange>())
1347 .def("what", &ErrorReport::what)
1348 .def_static("call_stack", ErrorReport::current_call_stack);
1349
1350 py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(
1351 m, "CompilationUnit")
1352 .def(
1353 py::init([](const std::string& lang, const uint32_t _frames_up) {
1354 auto cu = std::make_shared<CompilationUnit>();
1355 if (!lang.empty()) {
1356 pyCompilationUnitDefine(*cu, lang, nullptr, _frames_up);
1357 }
1358 return cu;
1359 }),
1360 py::arg("lang") = "",
1361 py::arg("_frames_up") = 0)
1362
1363 .def(
1364 "find_function",
1365 [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
1366 auto fn = self->find_function(QualifiedName(name));
1367 if (fn) {
1368 return std::optional<StrongFunctionPtr>(
1369 StrongFunctionPtr(std::move(self), fn));
1370 } else {
1371 return std::optional<StrongFunctionPtr>(std::nullopt);
1372 }
1373 })
1374 .def(
1375 "__getattr__",
1376 [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
1377 auto fn = self->find_function(QualifiedName(name));
1378 if (fn) {
1379 return StrongFunctionPtr(std::move(self), fn);
1380 } else {
1381 throw AttributeError(
1382 "'CompilationUnit' has no attribute '%s'", name.c_str());
1383 }
1384 })
1385 .def(
1386 "get_functions",
1387 [](const std::shared_ptr<CompilationUnit>& self) {
1388 auto raw_functions = self->get_functions();
1389 std::vector<StrongFunctionPtr> functions;
1390 functions.reserve(raw_functions.size());
1391 for (auto fn : raw_functions) {
1392 if (fn) {
1393 functions.emplace_back(self, fn);
1394 }
1395 }
1396 return functions;
1397 })
1398 .def("set_optimized", &CompilationUnit::set_optimized)
1399 .def(
1400 "define",
1401 pyCompilationUnitDefine,
1402 py::arg("src"),
1403 py::arg("rcb") = nullptr,
1404 py::arg("_frames_up") = 0)
1405 .def(
1406 "create_function",
1407 [](std::shared_ptr<CompilationUnit>& self,
1408 const std::string& qualified_name,
1409 std::shared_ptr<Graph> graph,
1410 bool should_mangle) {
1411 Function* fn = self->create_function(
1412 qualified_name, std::move(graph), should_mangle);
1413 return StrongFunctionPtr(std::move(self), fn);
1414 },
1415 py::arg("qualified_name"),
1416 py::arg("graph"),
1417 py::arg("should_mangle") = false)
1418 .def(
1419 "get_interface",
1420 [](const std::shared_ptr<CompilationUnit>& self,
1421 const std::string& name) { return self->get_interface(name); })
1422 .def(
1423 "get_class",
1424 [](const std::shared_ptr<CompilationUnit>& self,
1425 const std::string& name) { return self->get_class(name); })
1426 .def(
1427 "drop_all_functions",
1428 [](const std::shared_ptr<CompilationUnit>& self) {
1429 self->drop_all_functions();
1430 });
1431
  // Python binding for a free TorchScript function. The bound
  // StrongFunctionPtr exposes the wrapped Function through `function_`.
  // py::dynamic_attr() lets Python code set arbitrary attributes on
  // instances.
  py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
      // args[0] is the ScriptFunction itself; the remaining positional
      // arguments are forwarded via tuple_slice. Copying the StrongFunctionPtr
      // out of args[0] presumably keeps its owner alive for the call.
      .def(
          "__call__",
          [](py::args args, const py::kwargs& kwargs) {
            HANDLE_TH_ERRORS
            // see: [pybind11 varargs]
            auto strongPtr = py::cast<StrongFunctionPtr>(args[0]);
            Function& callee = *strongPtr.function_;
            py::object result = invokeScriptFunctionFromPython(
                callee, tuple_slice(std::move(args), 1), kwargs);
            return result;
            END_HANDLE_TH_ERRORS_PYBIND
          })
      // Saves the function by adding it to a placeholder module
      // ("__torch__.PlaceholderModule") and saving that module to `filename`.
      .def(
          "save",
          [](const StrongFunctionPtr& self,
             const std::string& filename,
             const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
            Module module("__torch__.PlaceholderModule");
            // [issue 27343]
            // Modules have 'training' attributes by default, but due to
            // https://github.com/pytorch/pytorch/issues/27343, functions end
            // up having a training attribute when they are loaded. This adds
            // a fake 'training' attribute that shouldn't be used, but prevents
            // jitter on saving and loading. Once that issue is fixed this can
            // be deleted.
            module.register_attribute("training", BoolType::get(), true);
            addFunctionToModule(module, self);
            module.save(filename, _extra_files);
          },
          py::arg("filename"),
          py::arg("_extra_files") = ExtraFilesMap())
      // In-memory variant of `save`. Returns the serialized bytes.
      .def(
          "save_to_buffer",
          [](const StrongFunctionPtr& self,
             const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
            std::ostringstream buf;
            Module module("__torch__.PlaceholderModule");
            // see [issue 27343]
            module.register_attribute("training", BoolType::get(), true);
            addFunctionToModule(module, self);
            module.save(buf, _extra_files);
            return py::bytes(buf.str());
          },
          py::arg("_extra_files") = ExtraFilesMap())
      // The function's graph. Goes through toGraphFunction, which
      // presumably rejects non-graph functions.
      .def_property_readonly(
          "graph",
          [](const StrongFunctionPtr& self) {
            return toGraphFunction(*self.function_).graph();
          })
      // Inlines into a copy, so the function's own graph is left untouched.
      .def_property_readonly(
          "inlined_graph",
          [](const StrongFunctionPtr& self) {
            auto g = toGraphFunction(*self.function_).graph()->copy();
            Inline(*g);
            return g;
          })
      .def_property_readonly(
          "schema",
          [](const StrongFunctionPtr& self) {
            return self.function_->getSchema();
          })
      // TorchScript source of the function, pretty-printed.
      .def_property_readonly(
          "code",
          [](const StrongFunctionPtr& self) {
            std::vector<at::IValue> constants;
            PrintDepsTable deps;

            PythonPrint pp(constants, deps);
            pp.printFunction(*self.function_);
            return pp.str();
          })
      .def(
          "get_debug_state",
          [](const StrongFunctionPtr& self) {
            return toGraphFunction(*self.function_)
                .get_executor()
                .getDebugState();
          })
      .def(
          "_debug_flush_compilation_cache",
          [](const StrongFunctionPtr& self) {
            toGraphFunction(*self.function_)
                .get_executor()
                .debugFlushCompilationCache();
          })
      .def_property_readonly(
          "name",
          [](const StrongFunctionPtr& self) { return self.function_->name(); })
      // Asserts the function is a graph function before toggling its
      // ignore-AMP flag.
      .def(
          "_set_ignore_amp",
          [](StrongFunctionPtr& self, bool ignore) {
            auto fn = self.function_;
            TORCH_INTERNAL_ASSERT(fn->isGraphFunction());
            GraphFunction& g_fn = toGraphFunction(*fn);
            g_fn._set_ignore_amp(ignore);
          })
      .def_property_readonly(
          "qualified_name",
          [](const StrongFunctionPtr& self) {
            return self.function_->qualname().qualifiedName();
          })
      .def_property_readonly("__doc__", [](const StrongFunctionPtr& self) {
        return self.function_->doc_string();
      });
1537
1538 py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
1539 .def(
1540 "__call__",
1541 [](py::args args, const py::kwargs& kwargs) {
1542 // see: [pybind11 varargs]
1543 HANDLE_TH_ERRORS
1544 Method& method = py::cast<Method&>(args[0]);
1545
1546 return invokeScriptMethodFromPython(
1547 method, tuple_slice(std::move(args), 1), kwargs);
1548 END_HANDLE_TH_ERRORS_PYBIND
1549 })
1550 .def_property_readonly("graph", &Method::graph)
1551 .def_property_readonly(
1552 "inlined_graph",
1553 [](const Method& self) {
1554 auto g = toGraphFunction(self.function()).graph()->copy();
1555 Inline(*g);
1556 return g;
1557 })
1558 .def_property_readonly(
1559 "schema", [](Method& m) { return m.function().getSchema(); })
1560 .def_property_readonly("name", &Method::name)
1561 .def_property_readonly(
1562 "code",
1563 [](Method& self) {
1564 std::vector<at::IValue> constants;
1565 PrintDepsTable deps;
1566 PythonPrint pp(constants, deps);
1567 pp.printMethod(self.function());
1568 return pp.str();
1569 })
1570 .def(
1571 "_debug_flush_compilation_cache",
1572 [](Method& self) {
1573 return self.get_executor().debugFlushCompilationCache();
1574 })
1575 .def_property_readonly(
1576 "code_with_constants",
1577 [](Method& self) {
1578 std::vector<at::IValue> constants;
1579 PrintDepsTable deps;
1580 PythonPrint pp(constants, deps);
1581 pp.printMethod(self.function());
1582 std::map<std::string, at::IValue> consts;
1583 int i = 0;
1584 for (auto const& constant : constants) {
1585 consts["c" + std::to_string(i)] = constant;
1586 i += 1;
1587 }
1588 return std::make_tuple(pp.str(), std::move(consts));
1589 })
1590 .def_property_readonly("owner", &Method::owner)
1591 .def_property_readonly("raw_owner", [](const Method& self) {
1592 return Object(self.raw_owner());
1593 });
1594 m.def("_generate_upgraders_graph", &generate_upgraders_graph);
1595 m.def(
1596 "_calculate_package_version_based_on_upgraders",
1597 &calculate_package_version_based_on_upgraders);
1598 m.def("_get_version_calculator_flag", &get_version_calculator_flag);
1599 m.def(
1600 "_compile_graph_to_code_table",
1601 [](const std::string& name, const std::shared_ptr<Graph>& graph) {
1602 CompilationOptions options;
1603 GraphFunction jitFunc(name, graph, nullptr);
1604 auto mobileFunc = convertJitFunctionToMobileFunction(jitFunc, options);
1605 return convertMobileFunctionToCodeTable(*mobileFunc, options);
1606 });
1607 m.def(
1608 "_jit_script_compile",
1609 [](const std::string& qualname,
1610 const Def& def,
1611 const ResolutionCallback& rcb,
1612 const FunctionDefaults& defaults) {
1613 C10_LOG_API_USAGE_ONCE("torch.script.compile");
1614 const auto name = c10::QualifiedName(qualname);
1615 TORCH_INTERNAL_ASSERT(name.name() == def.name().name());
1616 return script_compile_function(name, def, defaults, rcb);
1617 });
1618 m.def(
1619 "_jit_script_compile_overload",
1620 [](const std::string& qualname,
1621 const Decl& overload_decl,
1622 const Def& implementation_def,
1623 const ResolutionCallback& rcb,
1624 const FunctionDefaults& implementation_defaults,
1625 const py::object& signature) {
1626 const auto name = c10::QualifiedName(qualname);
1627 return script_compile_overloaded_function(
1628 name,
1629 overload_decl,
1630 implementation_def,
1631 rcb,
1632 implementation_defaults,
1633 signature);
1634 });
1635 m.def(
1636 "_replace_overloaded_method_decl",
1637 [](const Decl& overload_decl,
1638 const Def& implementation_def,
1639 const std::string& new_name) {
1640 checkOverloadDecl(overload_decl, implementation_def.decl());
1641 return implementation_def.withDecl(overload_decl).withName(new_name);
1642 });
1643 m.def(
1644 "_create_function_from_trace",
1645 [](const std::string& qualname,
1646 const py::function& func,
1647 const py::tuple& input_tuple,
1648 const py::function& var_name_lookup_fn,
1649 bool strict,
1650 bool force_outplace,
1651 const std::vector<std::string>& argument_names) {
1652 auto typed_inputs = toTraceableStack(input_tuple);
1653 std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
1654 func,
1655 typed_inputs,
1656 var_name_lookup_fn,
1657 strict,
1658 force_outplace,
1659 /*self=*/nullptr,
1660 argument_names));
1661
1662 auto cu = get_python_cu();
1663 auto name = c10::QualifiedName(qualname);
1664 auto result = cu->create_function(
1665 std::move(name), std::move(graph), /*shouldMangle=*/true);
1666 StrongFunctionPtr ret(std::move(cu), result);
1667 didFinishEmitFunction(ret);
1668 return ret;
1669 },
1670 py::arg("name"),
1671 py::arg("func"),
1672 py::arg("input_tuple"),
1673 py::arg("var_name_lookup_fn"),
1674 py::arg("strict"),
1675 py::arg("force_outplace"),
1676 py::arg("argument_names") = std::vector<std::string>());
1677
1678 m.def(
1679 "_create_function_from_trace_with_dict",
1680 [](const std::string& qualname,
1681 const py::function& func,
1682 const py::dict& input_dict,
1683 const py::function& var_name_lookup_fn,
1684 bool strict,
1685 bool force_outplace,
1686 const std::vector<std::string>& argument_names) {
1687 auto typed_inputs = toTraceableStack(input_dict);
1688 std::shared_ptr<Graph> graph =
1689 std::get<0>(tracer::createGraphByTracingWithDict(
1690 func,
1691 input_dict,
1692 typed_inputs,
1693 var_name_lookup_fn,
1694 strict,
1695 force_outplace,
1696 /*self=*/nullptr,
1697 argument_names));
1698
1699 auto cu = get_python_cu();
1700 auto name = c10::QualifiedName(qualname);
1701 auto result = cu->create_function(
1702 std::move(name), std::move(graph), /*shouldMangle=*/true);
1703 StrongFunctionPtr ret(std::move(cu), result);
1704 didFinishEmitFunction(ret);
1705 return ret;
1706 },
1707 py::arg("name"),
1708 py::arg("func"),
1709 py::arg("input_dict"),
1710 py::arg("var_name_lookup_fn"),
1711 py::arg("strict"),
1712 py::arg("force_outplace"),
1713 py::arg("argument_names") = std::vector<std::string>());
1714
1715 m.def(
1716 "_jit_script_class_compile",
1717 [](const std::string& qualifiedName,
1718 const ClassDef& classDef,
1719 const ClassMethodDefaults& defaults,
1720 const ResolutionCallback& rcb) {
1721 C10_LOG_API_USAGE_ONCE("torch.script.class");
1722 if (classDef.superclass().present()) {
1723 throw(
1724 ErrorReport(classDef.range())
1725 << "Torchscript does not support class inheritance.");
1726 }
1727 auto cu = get_python_cu();
1728 auto classname = c10::QualifiedName(qualifiedName);
1729 if (cu->get_type(classname) != nullptr) {
1730 classname = cu->mangle(classname);
1731 }
1732
1733 auto classType = ClassType::create(
1734 classname,
1735 cu,
1736 /* is_module = */ false,
1737 /* doc_string = */ "",
1738 getUnresolvedClassAttributes(classDef));
1739 cu->register_type(classType);
1740 std::vector<ResolverPtr> methodRcbs, propRcbs;
1741 std::vector<Def> methodDefs;
1742 std::vector<Property> props;
1743
1744 for (const auto& def : classDef.body()) {
1745 if (def.kind() != TK_DEF) {
1746 throw(
1747 ErrorReport(def.range())
1748 << "Currently class bodies can only contain method "
1749 "definitions. File an issue on GitHub if you want "
1750 "something else!");
1751 }
1752 methodDefs.emplace_back(def);
1753 methodRcbs.push_back(
1754 pythonResolver(rcb, classDef.name().name(), classType));
1755 }
1756
1757 // Gather definitions for property getters and setters as well as
1758 // corresponding resolution callbacks.
1759 if (classDef.properties().present()) {
1760 for (const auto& prop : classDef.properties().get()) {
1761 props.emplace_back(prop);
1762 propRcbs.push_back(
1763 pythonResolver(rcb, classDef.name().name(), classType));
1764 }
1765 }
1766
1767 const auto self = SimpleSelf(classType);
1768 cu->define(classname, props, propRcbs, methodDefs, methodRcbs, &self);
1769
1770 // Stitch in default arguments for methods. Properties don't need to be
1771 // considered since there is no way to invoke setters without passing in
1772 // a value.
1773 auto defs_it = methodDefs.begin();
1774 while (defs_it != methodDefs.end()) {
1775 auto def_name = (*defs_it).name().name();
1776 // If the method is not in the defaults map, assume there are
1777 // no default arguments for it.
1778 auto default_it = defaults.find(def_name);
1779 if (default_it == defaults.end()) {
1780 continue;
1781 }
1782
1783 const auto method_name =
1784 QualifiedName(classname, (*defs_it).name().name());
1785 auto& method = cu->get_function(method_name);
1786 method.setSchema(getSchemaWithNameAndDefaults(
1787 defs_it->range(),
1788 method.getSchema(),
1789 std::nullopt,
1790 default_it->second));
1791 ++defs_it;
1792 }
1793 return classType;
1794 });
1795 m.def(
1796 "_jit_script_interface_compile",
1797 [](const std::string& qualifiedName,
1798 const ClassDef& classDef,
1799 const ResolutionCallback& rcb,
1800 bool is_module) {
1801 auto cu = get_python_cu();
1802 auto className = c10::QualifiedName(qualifiedName);
1803 if (cu->get_type(className) != nullptr) {
1804 className = cu->mangle(className);
1805 }
1806
1807 get_python_cu()->define_interface(
1808 className, classDef, pythonResolver(rcb), is_module);
1809 return className.qualifiedName();
1810 });
1811
  // Python binding for ErrorReport::CallStack. It is constructed from a
  // (string, SourceRange) pair; py::dynamic_attr() allows arbitrary Python
  // attributes on instances.
  py::class_<torch::jit::ErrorReport::CallStack>(
      m, "CallStack", py::dynamic_attr())
      .def(py::init<const std::string&, const SourceRange&>());
1815
1816 m.def("_parse_source_def", [](const std::string& src) {
1817 Parser p(std::make_shared<Source>(src));
1818 return Def(p.parseFunction(/*is_method=*/true));
1819 });
1820 m.def("parse_type_comment", [](const std::string& comment) {
1821 Parser p(std::make_shared<Source>(comment));
1822 return Decl(p.parseTypeComment());
1823 });
1824
  // Upgrader-map introspection.
  m.def("_get_upgraders_map_size", &get_upgraders_map_size);
  m.def("_dump_upgraders_map", &dump_upgraders_map);

  // Hooks that mutate the upgrader map; test-only per their names.
  m.def("_test_only_populate_upgraders", &test_only_populate_upgraders);
  m.def("_test_only_remove_upgraders", &test_only_remove_upgraders);

  // Type-comment merging and operator-version map accessors. The
  // _test_only_* entries add or remove op-version map entries.
  m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
  m.def("_get_max_operator_version", &getMaxOperatorVersion);
  m.def("_get_operator_version_map", &get_operator_version_map);
  m.def("_get_upgraders_entry_map", &get_upgraders_entry_map);
  m.def("_get_upgrader_ranges", &getUpgradersRangeForOp);
  m.def("_test_only_add_entry_to_op_version_map", &test_only_add_entry);
  m.def("_test_only_remove_entry_to_op_version_map", &test_only_remove_entry);
1838 m.def(
1839 "import_ir_module",
1840 [](std::shared_ptr<CompilationUnit> cu,
1841 const std::string& filename,
1842 py::object map_location,
1843 const py::dict& extra_files,
1844 bool restore_shapes = false) {
1845 std::optional<at::Device> optional_device;
1846 if (!map_location.is_none()) {
1847 AT_ASSERT(THPDevice_Check(map_location.ptr()));
1848 optional_device =
1849 reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1850 }
1851 ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
1852 auto ret = import_ir_module(
1853 std::move(cu),
1854 filename,
1855 optional_device,
1856 extra_files_map,
1857 /*load_debug_files*/ true,
1858 restore_shapes);
1859 extra_files_to_python(extra_files_map, extra_files);
1860 return ret;
1861 });
1862 m.def(
1863 "_import_ir_module_from_package",
1864 [](std::shared_ptr<CompilationUnit> cu,
1865 std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader,
1866 std::shared_ptr<torch::jit::DeserializationStorageContext>
1867 storage_context,
1868 py::object map_location,
1869 const std::string& ts_id) {
1870 std::optional<at::Device> optional_device;
1871 if (!map_location.is_none()) {
1872 AT_ASSERT(THPDevice_Check(map_location.ptr()));
1873 optional_device =
1874 reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1875 }
1876 return import_ir_module(
1877 std::move(cu),
1878 std::move(reader),
1879 std::move(storage_context),
1880 optional_device,
1881 ts_id);
1882 });
1883 m.def(
1884 "import_ir_module_from_buffer",
1885 [](std::shared_ptr<CompilationUnit> cu,
1886 const std::string& buffer,
1887 py::object map_location,
1888 const py::dict& extra_files,
1889 bool restore_shapes = false) {
1890 std::istringstream in(buffer);
1891 std::optional<at::Device> optional_device;
1892 if (!map_location.is_none()) {
1893 AT_ASSERT(THPDevice_Check(map_location.ptr()));
1894 optional_device =
1895 reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1896 }
1897 ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
1898 auto ret = import_ir_module(
1899 std::move(cu),
1900 in,
1901 optional_device,
1902 extra_files_map,
1903 /*load_debug_files*/ true,
1904 restore_shapes);
1905 extra_files_to_python(extra_files_map, extra_files);
1906 return ret;
1907 });
1908 m.def(
1909 "_load_for_lite_interpreter",
1910 [](const std::string& filename, py::object map_location) {
1911 std::optional<at::Device> optional_device;
1912 if (!map_location.is_none()) {
1913 AT_ASSERT(THPDevice_Check(map_location.ptr()));
1914 optional_device =
1915 reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1916 }
1917 return _load_for_mobile(filename, optional_device);
1918 });
1919 m.def(
1920 "_load_for_lite_interpreter_from_buffer",
1921 [](const std::string& buffer, py::object map_location) {
1922 std::istringstream in(buffer);
1923 std::optional<at::Device> optional_device;
1924 if (!map_location.is_none()) {
1925 AT_ASSERT(THPDevice_Check(map_location.ptr()));
1926 optional_device =
1927 reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1928 }
1929 return _load_for_mobile(in, optional_device);
1930 });
1931 m.def(
1932 "_backport_for_mobile",
1933 [](const std::string& filename_input,
1934 const std::string& filename_output,
1935 const int64_t version) {
1936 return _backport_for_mobile(filename_input, filename_output, version);
1937 });
1938 m.def(
1939 "_backport_for_mobile_from_buffer",
1940 [](const std::string& buffer_input,
1941 const std::string& filename_output,
1942 const int64_t version) {
1943 std::istringstream in(buffer_input);
1944 return _backport_for_mobile(in, filename_output, version);
1945 });
1946 m.def(
1947 "_backport_for_mobile_to_buffer",
1948 [](const std::string& filename_input, const int64_t version) {
1949 std::ostringstream buffer_output;
1950 bool success =
1951 _backport_for_mobile(filename_input, buffer_output, version);
1952 return success ? py::bytes(buffer_output.str()) : py::bytes("");
1953 });
1954 m.def(
1955 "_backport_for_mobile_from_buffer_to_buffer",
1956 [](const std::string& buffer_input, const int64_t version) {
1957 std::istringstream in(buffer_input);
1958 std::ostringstream buffer_output;
1959 bool success = _backport_for_mobile(in, buffer_output, version);
1960 return success ? py::bytes(buffer_output.str()) : py::bytes("");
1961 });
1962 m.def("_get_model_bytecode_version", [](const std::string& filename) {
1963 return _get_model_bytecode_version(filename);
1964 });
1965 m.def(
1966 "_get_model_extra_files",
1967 [](const std::string& filename, const py::dict& py_extra_files) {
1968 std::optional<at::Device> optional_device;
1969 ExtraFilesMap cpp_extra_files = ExtraFilesMap();
1970 _load_for_mobile(filename, optional_device, cpp_extra_files);
1971 extra_files_to_python(cpp_extra_files, py_extra_files);
1972
1973 return py_extra_files;
1974 });
1975 m.def(
1976 "_get_model_bytecode_version_from_buffer", [](const std::string& buffer) {
1977 std::istringstream in(buffer);
1978 return _get_model_bytecode_version(in);
1979 });
1980 m.def(
1981 "_get_model_extra_files_from_buffer",
1982 [](const std::string& buffer, const py::dict& py_extra_files) {
1983 std::optional<at::Device> optional_device;
1984 ExtraFilesMap cpp_extra_files = ExtraFilesMap();
1985 std::istringstream in(buffer);
1986 _load_for_mobile(in, optional_device, cpp_extra_files);
1987 extra_files_to_python(cpp_extra_files, py_extra_files);
1988
1989 return py_extra_files;
1990 });
1991 m.def("_get_mobile_model_contained_types", [](const std::string& filename) {
1992 return _get_mobile_model_contained_types(filename);
1993 });
1994 m.def(
1995 "_get_mobile_model_contained_types_from_buffer",
1996 [](const std::string& buffer) {
1997 std::istringstream in(buffer);
1998 return _get_mobile_model_contained_types(in);
1999 });
2000 m.def("_nn_module_to_mobile", [](const Module& module) {
2001 CompilationOptions options;
2002 return jitModuleToMobile(module, options);
2003 });
2004 py::class_<OperatorInfo>(m, "OperatorInfo")
2005 .def_readonly("num_schema_args", &OperatorInfo::num_schema_args);
2006 m.def("_get_model_ops_and_info", [](const std::string& filename) {
2007 return _get_model_ops_and_info(filename);
2008 });
2009 m.def("_get_model_ops_and_info_from_buffer", [](const std::string& buffer) {
2010 std::istringstream in(buffer);
2011 return _get_model_ops_and_info(in);
2012 });
2013 m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
2014 return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
2015 });
2016 m.def(
2017 "_quantize_ondevice_ptq_dynamic",
2018 [](mobile::Module& m, const std::string& method_name) {
2019 mobile::quantization::PTQQuanizationHelper ptq_helper;
2020 ptq_helper.quantize_dynamic(m, method_name);
2021 });
2022
  // Set / get the emit hooks (thin wrappers over setEmitHooks/getEmitHooks).
  m.def("_jit_set_emit_hooks", setEmitHooks);
  m.def("_jit_get_emit_hooks", getEmitHooks);
  // Clear the classes registered in the global Python compilation unit.
  m.def("_jit_clear_class_registry", []() {
    get_python_cu()->_clear_python_cu();
  });
  // Debug switches controlling inlining of autodiff subgraphs and fusion
  // groups.
  m.def(
      "_debug_set_autodiff_subgraph_inlining",
      debugSetAutodiffSubgraphInlining);
  m.def("_debug_set_fusion_group_inlining", debugSetFusionGroupInlining);
  m.def("_debug_get_fusion_group_inlining", getFusionGroupInlining);
  // Shape-propagation entry points.
  m.def("_propagate_shapes", _propagate_shapes);
  m.def(
      "_propagate_and_assign_input_shapes", _propagate_and_assign_input_shapes);
  m.def(
      "_last_executed_optimized_graph",
      []() { return lastExecutedOptimizedGraph(); },
      "Retrieve the optimized graph that was run the last time the graph executor ran on this thread");
2040 m.def(
2041 "_create_function_from_graph",
2042 [](const std::string& qualname, std::shared_ptr<Graph> graph) {
2043 // TODO this should go in the global Python CU
2044 auto cu = std::make_shared<CompilationUnit>();
2045 c10::QualifiedName name(qualname);
2046 auto fn = cu->create_function(std::move(name), std::move(graph));
2047 return StrongFunctionPtr(std::move(cu), fn);
2048 });
2049 m.def("_ivalue_tags_match", ivalue_tags_match);
2050 m.def("_ivalue_debug_python_object", [](py::object py_obj) {
2051 // convert to IValue first, IValue will incref via py::object
2052 IValue pyobj_ivalue = toIValue(std::move(py_obj), PyObjectType::get());
2053 // convert back to PyObject by borrowing the reference, which also
2054 // incref, after the return of this function, IValue is out of scope
2055 // which decref, so the return value is original refcount + 1
2056 py::object ret = toPyObject(pyobj_ivalue);
2057 return ret;
2058 });
2059 m.def("_jit_debug_module_iterators", _jit_debug_module_iterators);
2060
  // Python bindings for testing::FileCheck. The `check*` builders are bound
  // straight to the member functions. `check_count` and `run` get extra
  // lambda overloads: one takes keyword arguments, others take a Graph
  // instead of a string. pybind11 tries overloads in registration order, so
  // the order of the `.def` calls below matters.
  py::class_<testing::FileCheck>(m, "FileCheck")
      .def(py::init<>())
      .def("check", &testing::FileCheck::check)
      .def("check_not", &testing::FileCheck::check_not)
      .def("check_same", &testing::FileCheck::check_same)
      .def("check_next", &testing::FileCheck::check_next)
      .def("check_count", &testing::FileCheck::check_count)
      .def("check_dag", &testing::FileCheck::check_dag)
      .def(
          "check_source_highlighted",
          &testing::FileCheck::check_source_highlighted)
      .def("check_regex", &testing::FileCheck::check_regex)
      // Keyword-argument overload of check_count; `exactly` defaults to False.
      .def(
          "check_count",
          [](testing::FileCheck& f,
             const std::string& str,
             size_t count,
             bool exactly) { return f.check_count(str, count, exactly); },
          "Check Count",
          py::arg("str"),
          py::arg("count"),
          py::arg("exactly") = false)
      // run(test_string)
      .def(
          "run",
          [](testing::FileCheck& f, const std::string& str) {
            return f.run(str);
          })
      // run(graph)
      .def(
          "run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); })
      // run(checks_file, test_file): checks supplied as a string.
      .def(
          "run",
          [](testing::FileCheck& f,
             const std::string& input,
             const std::string& output) { return f.run(input, output); },
          "Run",
          py::arg("checks_file"),
          py::arg("test_file"))
      // run(checks_file, graph): checks supplied as a string, run on a Graph.
      .def(
          "run",
          [](testing::FileCheck& f, const std::string& input, const Graph& g) {
            return f.run(input, g);
          },
          "Run",
          py::arg("checks_file"),
          py::arg("graph"));
2106
  // Install `logger` through logging::setLogger and return its result with
  // return_value_policy::reference, so Python does not take ownership of
  // the returned pointer.
  m.def(
      "_logging_set_logger",
      [](logging::LoggerBase* logger) { return logging::setLogger(logger); },
      py::return_value_policy::reference);
  m.def("_set_graph_executor_optimize", [](bool optimize) {
    setGraphExecutorOptimize(optimize);
  });

  // Return the current graph-executor optimize flag. If a new value is given,
  // install it after reading the old one. Note that the Python keyword is
  // `new_settings` (plural) while the C++ parameter is `new_setting`.
  m.def(
      "_get_graph_executor_optimize",
      [](std::optional<bool> new_setting = std::nullopt) {
        bool old_value = getGraphExecutorOptimize();
        if (new_setting) {
          setGraphExecutorOptimize(*new_setting);
        }
        return old_value;
      },
      py::arg("new_settings") = nullptr);

  m.def(
      "_enable_mobile_interface_call_export",
      &torch::jit::enableMobileInterfaceCallExport);
2129
2130 m.def("_create_module_with_type", [](const ClassTypePtr& type) {
2131 return Module(get_python_cu(), type);
2132 }).def("_create_object_with_type", [](const ClassTypePtr& type) {
2133 return Object(get_python_cu(), type);
2134 });
2135
2136 m.def("_export_opnames", [](Module& sm) {
2137 return debugMakeList(torch::jit::export_opnames(sm));
2138 });
2139
  // Python-facing builder for a ConcreteModuleType. The add_* methods record
  // constants, attributes, functions, hooks, submodules and overloads; the
  // set_* methods record poisoning or which IterableModuleKind the module
  // is; `build` produces the ConcreteModuleType.
  py::class_<
      ConcreteModuleTypeBuilder,
      std::shared_ptr<ConcreteModuleTypeBuilder>>(
      m, "ConcreteModuleTypeBuilder")
      .def(py::init<py::object>())
      .def(
          "add_constant",
          [](ConcreteModuleTypeBuilder& self,
             std::string name,
             py::object value) {
            self.addConstant(std::move(name), std::move(value));
          })
      .def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
      .def(
          "add_function_attribute",
          &ConcreteModuleTypeBuilder::addFunctionAttribute)
      .def(
          "add_builtin_function",
          &ConcreteModuleTypeBuilder::addBuiltinFunction)
      .def("add_forward_hook", &ConcreteModuleTypeBuilder::addForwardHook)
      .def(
          "add_forward_pre_hook", &ConcreteModuleTypeBuilder::addForwardPreHook)
      .def("add_module", &ConcreteModuleTypeBuilder::addModule)
      .def("add_overload", &ConcreteModuleTypeBuilder::addOverload)
      .def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned)
      .def(
          "add_failed_attribute",
          &ConcreteModuleTypeBuilder::addFailedAttribute)
      .def(
          "add_ignored_attribute",
          &ConcreteModuleTypeBuilder::addIgnoredAttribute)
      // Batch form of add_ignored_attribute.
      .def(
          "add_ignored_attributes",
          [](ConcreteModuleTypeBuilder& self,
             const std::vector<std::string>& names) {
            for (auto& name : names) {
              self.addIgnoredAttribute(name);
            }
          })
      // Each set_module_* / set_parameter_* call below records one
      // IterableModuleKind value on the builder.
      .def(
          "set_module_dict",
          [](ConcreteModuleTypeBuilder& self) {
            self.setIterableModuleKind(IterableModuleKind::DICT);
          })
      .def("build", &ConcreteModuleTypeBuilder::build)
      .def(
          "equals",
          [](const ConcreteModuleTypeBuilder& self,
             const ConcreteModuleTypeBuilder& other) {
            return self.equals(other);
          })
      .def(
          "set_module_list",
          [](ConcreteModuleTypeBuilder& self) {
            self.setIterableModuleKind(IterableModuleKind::LIST);
          })
      .def(
          "set_parameter_list",
          [](ConcreteModuleTypeBuilder& self) {
            self.setIterableModuleKind(IterableModuleKind::PARAMLIST);
          })
      .def("set_parameter_dict", [](ConcreteModuleTypeBuilder& self) {
        self.setIterableModuleKind(IterableModuleKind::PARAMDICT);
      });
2204
2205 py::class_<ConcreteModuleType, std::shared_ptr<ConcreteModuleType>>(
2206 m, "ConcreteModuleType")
2207 .def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
2208 .def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
2209 .def_static("from_jit_type", &ConcreteModuleType::fromJitType)
2210 .def("get_constants", &ConcreteModuleType::getConstantsPy)
2211 .def("get_attributes", &ConcreteModuleType::getAttributesPy)
2212 .def("get_modules", &ConcreteModuleType::getModulesPy)
2213 .def("dump", &ConcreteModuleType::dump)
2214 .def("is_ignored_attribute", &ConcreteModuleType::isIgnoredAttribute)
2215 .def(
2216 "equals",
2217 [](const ConcreteModuleType& self, const ConcreteModuleType& other) {
2218 return self.equals(other);
2219 })
2220 .def(
2221 "equals",
2222 [](const ConcreteModuleType& self,
2223 const ConcreteModuleTypeBuilder& other) {
2224 return self.equals(other);
2225 })
2226 .def(
2227 "_create_methods_and_properties",
2228 [](std::shared_ptr<ConcreteModuleType> concreteType,
2229 const std::vector<Property>& properties,
2230 const std::vector<ResolutionCallback>& propertyRcbs,
2231 const std::vector<Def>& methodDefs,
2232 const std::vector<ResolutionCallback>& methodRcbs,
2233 const std::vector<FunctionDefaults>& defaults) {
2234 TORCH_INTERNAL_ASSERT(methodDefs.size() == methodRcbs.size());
2235 TORCH_INTERNAL_ASSERT(properties.size() == propertyRcbs.size());
2236
2237 std::vector<ResolverPtr> methodResolvers, propertyResolvers;
2238 methodResolvers.reserve(methodRcbs.size());
2239 for (auto& callback : methodRcbs) {
2240 methodResolvers.push_back(pythonResolver(callback));
2241 }
2242
2243 propertyResolvers.reserve(propertyRcbs.size());
2244 for (auto& callback : propertyRcbs) {
2245 propertyResolvers.push_back(pythonResolver(callback));
2246 }
2247
2248 const auto& selfType =
2249 concreteType->getJitType()->expect<ClassType>();
2250 const auto& prefix = selfType->name().value();
2251 const auto self = ModuleSelf(std::move(concreteType));
2252 auto cu = selfType->compilation_unit();
2253 cu->define(
2254 prefix,
2255 properties,
2256 propertyResolvers,
2257 methodDefs,
2258 methodResolvers,
2259 &self);
2260 // Stitch in default arguments for each Def if provided
2261 auto defaults_it = defaults.begin();
2262 auto defs_it = methodDefs.begin();
2263 while (defs_it != methodDefs.end()) {
2264 const auto method_name =
2265 QualifiedName(prefix, (*defs_it).name().name());
2266 auto& method = cu->get_function(method_name);
2267 method.setSchema(getSchemaWithNameAndDefaults(
2268 defs_it->range(),
2269 method.getSchema(),
2270 std::nullopt,
2271 *defaults_it));
2272 ++defs_it;
2273 ++defaults_it;
2274 }
2275 })
2276 .def(
2277 "_create_hooks",
2278 [](std::shared_ptr<ConcreteModuleType> concreteType,
2279 const std::vector<Def>& hookDefs,
2280 const std::vector<ResolutionCallback>& hookRcbs,
2281 const std::vector<Def>& preHookDefs,
2282 const std::vector<ResolutionCallback>& preHookRcbs) {
2283 TORCH_INTERNAL_ASSERT(hookDefs.size() == hookRcbs.size());
2284 TORCH_INTERNAL_ASSERT(preHookDefs.size() == preHookRcbs.size());
2285
2286 std::vector<ResolverPtr> hookResolvers, preHookResolvers;
2287
2288 hookResolvers.reserve(hookRcbs.size());
2289 for (auto& callback : hookRcbs) {
2290 hookResolvers.push_back(pythonResolver(callback));
2291 }
2292
2293 preHookResolvers.reserve(preHookRcbs.size());
2294 for (auto& callback : preHookRcbs) {
2295 preHookResolvers.push_back(pythonResolver(callback));
2296 }
2297
2298 const auto& selfType =
2299 concreteType->getJitType()->expect<ClassType>();
2300 const auto& prefix = selfType->name().value();
2301 const auto self = ModuleSelf(std::move(concreteType));
2302 auto cu = selfType->compilation_unit();
2303 cu->define_hooks(
2304 prefix,
2305 hookDefs,
2306 hookResolvers,
2307 preHookDefs,
2308 preHookResolvers,
2309 &self);
2310 });
2311
2312 m.def(
2313 "_resolve_type",
2314 [](const std::string& name,
2315 const SourceRange& range,
2316 const ResolutionCallback& rcb) {
2317 return pythonResolver(rcb)->resolveType(name, range);
2318 });
2319 m.def(
2320 "_resolve_type_from_object",
2321 [](const py::object& obj,
2322 const SourceRange& range,
2323 const ResolutionCallback& rcb) {
2324 return pythonResolver(rcb)->resolveTypeFromObject(obj, range);
2325 });
2326
2327 m.def(
2328 "_run_emit_module_hook", [](const Module& m) { didFinishEmitModule(m); });
2329
2330 m.def(
2331 "_set_should_use_format_with_string_table",
2332 setShouldUseFormatWithStringTable);
2333
  // Logger class hierarchy exposed to Python: an opaque LoggerBase plus the
  // LockingLogger (with a SUM/AVG aggregation mode) and NoopLogger
  // implementations, all held by std::shared_ptr.
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
      m, "LoggerBase");
  py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
      .value("SUM", logging::LockingLogger::AggregationType::SUM)
      .value("AVG", logging::LockingLogger::AggregationType::AVG)
      .export_values();
  py::class_<
      logging::LockingLogger,
      logging::LoggerBase,
      std::shared_ptr<logging::LockingLogger>>(m, "LockingLogger")
      .def(py::init<>())
      .def("set_aggregation_type", &logging::LockingLogger::setAggregationType)
      .def("get_counter_val", &logging::LockingLogger::getCounterValue);
  py::class_<
      logging::NoopLogger,
      logging::LoggerBase,
      std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
      .def(py::init<>());
  // True iff `obj` is an instance of the bound torch::jit::Object class.
  m.def("_jit_is_script_object", [](const py::object& obj) {
    return py::isinstance<Object>(obj);
  });
2356
2357 m.def("_get_file_format", [](const std::string& path) {
2358 switch (getFileFormat(path)) {
2359 case FileFormat::FlatbufferFileFormat:
2360 return "flatbuffer";
2361 case FileFormat::ZipFileFormat:
2362 return "zipfile";
2363 default:
2364 return "invalid";
2365 }
2366 });
2367
2368 m.def(
2369 "_save_parameters",
2370 [](const std::map<std::string, at::Tensor>& map,
2371 const std::string& filename,
2372 bool use_flatbuffer = false) {
2373 _save_parameters(map, filename, use_flatbuffer);
2374 });
2375
2376 m.def("_load_mobile_module_from_file", [](const std::string& filename) {
2377 return torch::jit::load_mobile_module_from_file(filename);
2378 });
2379 m.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
2380 auto bytes_copy = copyStr(bytes);
2381 return torch::jit::parse_and_initialize_mobile_module(
2382 bytes_copy, bytes.size());
2383 });
2384 m.def("_load_jit_module_from_file", [](const std::string& filename) {
2385 ExtraFilesMap extra_files = ExtraFilesMap();
2386 return torch::jit::load_jit_module_from_file(filename, extra_files);
2387 });
2388 m.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
2389 auto bytes_copy = copyStr(bytes);
2390 ExtraFilesMap extra_files = ExtraFilesMap();
2391 return torch::jit::parse_and_initialize_jit_module(
2392 bytes_copy, bytes.size(), extra_files);
2393 });
2394 m.def(
2395 "_save_mobile_module",
2396 [](const torch::jit::mobile::Module& module,
2397 const std::string& filename,
2398 const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2399 return torch::jit::save_mobile_module(module, filename, _extra_files);
2400 });
2401 m.def(
2402 "_save_jit_module",
2403 [](const torch::jit::Module& module,
2404 const std::string& filename,
2405 const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2406 return torch::jit::save_jit_module(module, filename, _extra_files);
2407 });
2408 m.def(
2409 "_save_mobile_module_to_bytes",
2410 [](const torch::jit::mobile::Module& module,
2411 const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2412 auto detached_buffer =
2413 torch::jit::save_mobile_module_to_bytes(module, _extra_files);
2414 return py::bytes(
2415 reinterpret_cast<char*>(detached_buffer->data()),
2416 detached_buffer->size());
2417 });
2418 m.def(
2419 "_save_jit_module_to_bytes",
2420 [](const torch::jit::Module& module,
2421 const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2422 auto detached_buffer =
2423 torch::jit::save_jit_module_to_bytes(module, _extra_files);
2424 return py::bytes(
2425 reinterpret_cast<char*>(detached_buffer->data()),
2426 detached_buffer->size());
2427 });
2428 m.def("_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
2429 py::gil_scoped_acquire acquire;
2430 py::dict result;
2431 mobile::ModuleInfo minfo =
2432 torch::jit::get_module_info_from_flatbuffer(&flatbuffer_content[0]);
2433 result["bytecode_version"] = minfo.bytecode_version;
2434 result["operator_version"] = minfo.operator_version;
2435 result["function_names"] = minfo.function_names;
2436 result["type_names"] = minfo.type_names;
2437 result["opname_to_num_args"] = minfo.opname_to_num_args;
2438 return result;
2439 });
2440
  // Serialize an IValue with torch::jit::pickle_save and return the bytes.
  m.def("_pickle_save", [](const IValue& v) {
    auto bytes = torch::jit::pickle_save(v);
    return py::bytes(bytes.data(), bytes.size());
  });

  // Deserialize bytes produced by _pickle_save via pickle_load_obj. The
  // py::bytes is first copied into a std::string, per the workaround for the
  // linked pybind11 issue.
  m.def("_pickle_load_obj", [](const py::bytes& bytes) {
    // https://github.com/pybind/pybind11/issues/2517
    std::string buffer = bytes;
    return torch::jit::pickle_load_obj(buffer);
  });

  // Register the ScriptDict / ScriptList bindings on the same Python module.
  initScriptDictBindings(module);
  initScriptListBindings(module);
2454 }
2455
2456 } // namespace torch::jit
2457