1 #include <torch/csrc/jit/serialization/import_source.h>
2
3 #include <ATen/core/ivalue_inl.h>
4 #include <ATen/core/qualified_name.h>
5 #include <torch/csrc/jit/frontend/parser.h>
6 #include <torch/csrc/jit/frontend/resolver.h>
7 #include <torch/csrc/jit/frontend/script_type_parser.h>
8 #include <torch/custom_class.h>
9
10 #include <regex>
11
12 namespace torch::jit {
13
14 struct OpsValue : public SugaredValue {
OpsValuetorch::jit::OpsValue15 OpsValue(size_t version) : version_(version) {}
kindtorch::jit::OpsValue16 std::string kind() const override {
17 return "ops";
18 }
attrtorch::jit::OpsValue19 std::shared_ptr<SugaredValue> attr(
20 const SourceRange& loc,
21 GraphFunction& m,
22 const std::string& field) override {
23 return std::make_shared<BuiltinModule>(field, version_);
24 }
25 size_t version_;
26 };
27
28 // Represents nested namespaces, like `foo.bar.Baz`.
29 // Right now these namespaces can only contain other namespaces or NamedTypes
30 struct TORCH_API ClassNamespaceValue : public SugaredValue {
31 /**
32 * @param name The fully qualified path, which can resolve either to a
33 * namespace or a NamedType
34 * @param si The source importer that searches for and loads
35 * classes/functions.
36 */
ClassNamespaceValuetorch::jit::ClassNamespaceValue37 explicit ClassNamespaceValue(
38 c10::QualifiedName name,
39 std::shared_ptr<SourceImporterImpl> si)
40 : basename_(std::move(name)), si_(std::move(si)) {}
41
42 std::shared_ptr<SugaredValue> attr(
43 const SourceRange& loc,
44 GraphFunction& m,
45 const std::string& name) override;
kindtorch::jit::ClassNamespaceValue46 std::string kind() const override {
47 return "Class Namespace";
48 }
49
50 private:
51 c10::QualifiedName basename_;
52 std::shared_ptr<SourceImporterImpl> si_;
53 };
54
55 // This value maps attributes CONSTANTS.c0 CONSTANTS.c1 to entries
56 // in the 'constants' vector. This table is will be stored in a container format
57 // and given to the import_method when restoring the code.
58 struct ConstantTableValue : public SugaredValue {
ConstantTableValuetorch::jit::ConstantTableValue59 explicit ConstantTableValue(const std::vector<at::IValue>* constants)
60 : constants_(constants) {}
kindtorch::jit::ConstantTableValue61 std::string kind() const override {
62 return "CONSTANTS";
63 }
64 // select an attribute on it, e.g. `this.field`
attrtorch::jit::ConstantTableValue65 std::shared_ptr<SugaredValue> attr(
66 const SourceRange& loc,
67 GraphFunction& m,
68 const std::string& field) override {
69 const char* field_s = field.c_str();
70 char* end = nullptr;
71 int64_t offset = strtoll(field_s + 1, &end, 10);
72 if (field.size() < 2 || *end != 0)
73 throw(ErrorReport(loc) << "invalid constant specifier: " << field);
74 if (offset < 0 || size_t(offset) >= constants_->size()) {
75 throw(
76 ErrorReport(loc) << "constant index " << offset
77 << " is out of bounds (constant table has "
78 << constants_->size() << " entries)");
79 }
80 auto ivalue = constants_->at(offset);
81 Value* value = nullptr;
82
83 // see [Constant Object Weak CompilationUnit Reference]
84 if (ivalue.isObject() && !ivalue.toObject()->is_weak_compilation_ref()) {
85 auto obj = ivalue.toObject();
86 if (!non_holding_object_cache.count(obj)) {
87 non_holding_object_cache[obj] = obj->copy_to_weak_compilation_ref();
88 }
89 value = m.graph()->insertConstant(non_holding_object_cache[obj], loc);
90 } else {
91 value = m.graph()->insertConstant(constants_->at(offset), loc);
92 }
93
94 // specializing tensor type on compilation messes up typing relations
95 value->setType(unshapedType(value->type()));
96
97 return std::make_shared<SimpleValue>(value);
98 }
99
100 private:
101 std::unordered_map<
102 c10::intrusive_ptr<at::ivalue::Object>,
103 c10::intrusive_ptr<at::ivalue::Object>>
104 non_holding_object_cache;
105 const std::vector<at::IValue>* constants_;
106 };
107
SourceImporterImpl(std::shared_ptr<CompilationUnit> cu,const std::vector<at::IValue> * constant_table,SourceLoader source_loader,size_t version)108 SourceImporterImpl::SourceImporterImpl(
109 std::shared_ptr<CompilationUnit> cu,
110 const std::vector<at::IValue>* constant_table,
111 SourceLoader source_loader,
112 size_t version)
113 : cu_(std::move(cu)),
114 source_loader_(std::move(source_loader)),
115 version_(version) {
116 env_ = {
117 {"torch", std::make_shared<BuiltinModule>("aten", version)},
118 {"ops", std::make_shared<OpsValue>(version)},
119 // Constants present in the model. Used to resolve "CONSTANTS.n" to the
120 // actual value
121 {"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
122 {"fork", SpecialFormValue::create(prim::fork)},
123 {"awaitable", SpecialFormValue::create(prim::awaitable)},
124 {"annotate", SpecialFormValue::create(prim::annotate)},
125 {"unchecked_cast", SpecialFormValue::create(prim::unchecked_cast)},
126 {"uninitialized", SpecialFormValue::create(prim::Uninitialized)},
127 };
128 }
129
findNamedType(const QualifiedName & name)130 TypePtr SourceImporterImpl::findNamedType(const QualifiedName& name) {
131 if (auto custom_class = getCustomClass(name.qualifiedName())) {
132 return custom_class;
133 }
134 parseSourceIfNeeded(name.prefix());
135 auto it = to_be_defined_.find(name);
136 if (it != to_be_defined_.end() && it->second->kind() == TK_CLASS_DEF) {
137 ClassDef cd(std::move(it->second));
138 to_be_defined_.erase(it);
139 importNamedType(name.prefix(), cd);
140 }
141 return cu_->get_type(name);
142 }
143
findFunction(const QualifiedName & name)144 Function* SourceImporterImpl::findFunction(const QualifiedName& name) {
145 parseSourceIfNeeded(name.prefix());
146 auto it = to_be_defined_.find(name);
147 if (it != to_be_defined_.end() && it->second->kind() == TK_DEF) {
148 Def d(it->second);
149 to_be_defined_.erase(it);
150 importFunction(name.prefix(), d);
151 }
152 return cu_->find_function(name);
153 }
154
parseSourceIfNeeded(const std::string & qualifier)155 void SourceImporterImpl::parseSourceIfNeeded(const std::string& qualifier) {
156 // qualifier may be blank, for instance checking if __torch__ is a class.
157 if (qualifier.empty() || loaded_sources_.count(qualifier)) {
158 return;
159 }
160 loaded_sources_.insert(qualifier);
161 std::shared_ptr<Source> src = source_loader_(qualifier);
162
163 // The importer, when looking for classes/functions doesn't know if 'foo'
164 // contains definitions or if it is a prefix of 'foo.bar', we only figure it
165 // out by testing if `foo.py` exists in the source loader. If it doesn't
166 // then there is nothing to load here
167 if (!src) {
168 return;
169 }
170 Parser p(src);
171 parsePossibleVersionNumber(p.lexer());
172
173 auto& L = p.lexer();
174
175 while (L.cur().kind != TK_EOF) {
176 parseImports(L);
177 auto tk = L.cur();
178 auto kind = tk.kind;
179 switch (kind) {
180 case TK_CLASS_DEF: {
181 auto parsed_treeref = ClassDef(p.parseClass());
182 to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] =
183 parsed_treeref;
184 } break;
185 case TK_DEF: {
186 auto parsed_treeref = Def(p.parseFunction(/*is_method=*/false));
187 to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] =
188 parsed_treeref;
189 } break;
190 default:
191 throw(
192 ErrorReport(L.cur().range)
193 << "Unexpected token in code import: " << kindToString(kind));
194 }
195 }
196 }
197
LEGACY_import_methods(const Module & mod,const std::shared_ptr<Source> & src)198 void SourceImporterImpl::LEGACY_import_methods(
199 const Module& mod,
200 const std::shared_ptr<Source>& src) {
201 auto self = SimpleSelf(mod.type());
202 c10::QualifiedName prefix = *mod.type()->name();
203 Parser p(src);
204
205 parsePossibleVersionNumber(p.lexer());
206
207 parseImports(p.lexer());
208
209 std::vector<Def> definitions;
210 std::vector<ResolverPtr> resolvers;
211 while (p.lexer().cur().kind != TK_EOF) {
212 auto def = Def(p.parseFunction(/*is_method=*/true));
213 definitions.emplace_back(def);
214 resolvers.emplace_back(shared_from_this());
215 }
216 cu_->define(
217 prefix,
218 /*properties=*/{},
219 /*propResolvers=*/{},
220 definitions,
221 resolvers,
222 &self);
223 }
224
resolveValue(const std::string & name,GraphFunction & m,const SourceRange & loc)225 std::shared_ptr<SugaredValue> SourceImporterImpl::resolveValue(
226 const std::string& name,
227 GraphFunction& m,
228 const SourceRange& loc) {
229 auto it = env_.find(name);
230 if (it != env_.end()) {
231 return it->second;
232 }
233 auto graph = m.graph();
234 if (name == "inf") {
235 return std::make_shared<SimpleValue>(
236 graph->insertConstant(std::numeric_limits<double>::infinity(), loc));
237 }
238 if (name == "nan") {
239 return std::make_shared<SimpleValue>(
240 graph->insertConstant(std::numeric_limits<double>::quiet_NaN(), loc));
241 }
242 if (name == "infj") {
243 return std::make_shared<SimpleValue>(graph->insertConstant(
244 c10::complex<double>(0, std::numeric_limits<double>::infinity()), loc));
245 }
246 if (name == "nanj") {
247 return std::make_shared<SimpleValue>(graph->insertConstant(
248 c10::complex<double>(0, std::numeric_limits<double>::quiet_NaN()),
249 loc));
250 }
251 if (name == "__torch__") {
252 return std::make_shared<ClassNamespaceValue>(
253 c10::QualifiedName(name), shared_from_this());
254 }
255 return nullptr;
256 }
257
resolveType(const std::string & name,const SourceRange & loc)258 TypePtr SourceImporterImpl::resolveType(
259 const std::string& name,
260 const SourceRange& loc) {
261 return findNamedType(QualifiedName(name));
262 }
263
importFunction(const std::string & qualifier,const Def & def)264 void SourceImporterImpl::importFunction(
265 const std::string& qualifier,
266 const Def& def) {
267 std::vector<Def> definitions{def};
268 std::vector<ResolverPtr> resolvers{shared_from_this()};
269 cu_->define(
270 qualifier,
271 /*properties=*/{},
272 /*propResolvers=*/{},
273 definitions,
274 resolvers,
275 nullptr);
276 }
277
importNamedType(const std::string & qualifier,const ClassDef & class_def)278 void SourceImporterImpl::importNamedType(
279 const std::string& qualifier,
280 const ClassDef& class_def) {
281 const auto qualified_name =
282 QualifiedName(QualifiedName(qualifier), class_def.name().name());
283 if (!class_def.superclass().present()) {
284 return importClass(qualified_name, class_def, /*is_module=*/false);
285 }
286 const auto& superclass_name = Var(class_def.superclass().get()).name().name();
287 if (superclass_name == "Module") {
288 importClass(qualified_name, class_def, /*is_module=*/true);
289 } else if (superclass_name == "NamedTuple") {
290 // NamedTuples have special rules (since they are TupleTypes and not
291 // ClassTypes)
292 return importNamedTuple(qualified_name, class_def);
293 } else if (superclass_name == "Interface") {
294 cu_->define_interface(
295 qualified_name, class_def, shared_from_this(), /*is_module=*/false);
296 } else if (superclass_name == "ModuleInterface") {
297 cu_->define_interface(
298 qualified_name, class_def, shared_from_this(), /*is_module=*/true);
299 } else if (superclass_name == "Enum") {
300 importEnum(qualified_name, class_def);
301 } else {
302 throw(
303 ErrorReport(class_def.range())
304 << "Torchscript does not support class inheritance.");
305 }
306 }
307
308 std::optional<Assign> SourceImporterImpl::
attributeAssignmentSpecialHandlingHack(const QualifiedName & qualified_classname,const Assign & assign)309 attributeAssignmentSpecialHandlingHack(
310 const QualifiedName& qualified_classname,
311 const Assign& assign) {
312 struct AttrTypeReplacementDescr {
313 std::string attr_name;
314 std::string expected_type;
315 std::string replacement_type;
316 };
317
318 // module demangled qualname -> ReplacementDescr
319 static std::unordered_map<std::string, AttrTypeReplacementDescr> replacements{
320 {"__torch__.torch.ao.nn.quantized.modules.linear.LinearPackedParams",
321 {"_packed_params",
322 "Tensor",
323 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
324 {"__torch__.torch.ao.nn.quantized.modules.linear.Linear",
325 {"_packed_params",
326 "Tensor",
327 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
328 {"__torch__.torch.ao.nn.quantized.dynamic.modules.linear.Linear",
329 {"_packed_params",
330 "Tensor",
331 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
332 {"__torch__.torch.ao.nn.quantized.modules.conv.Conv2d",
333 {"_packed_params",
334 "Tensor",
335 "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
336 {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d",
337 {"_packed_params",
338 "Tensor",
339 "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
340 {"__torch__.torch.ao.nn.quantized.modules.conv.Conv3d",
341 {"_packed_params",
342 "Tensor",
343 "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}},
344 {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d",
345 {"_packed_params",
346 "Tensor",
347 "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}},
348 // BC Stuff
349 {"__torch__.torch.nn.quantized.modules.linear.LinearPackedParams",
350 {"_packed_params",
351 "Tensor",
352 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
353 {"__torch__.torch.nn.quantized.modules.linear.Linear",
354 {"_packed_params",
355 "Tensor",
356 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
357 {"__torch__.torch.nn.quantized.modules.conv.Conv2d",
358 {"_packed_params",
359 "Tensor",
360 "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
361 {"__torch__.torch.nn.quantized.modules.conv.Conv3d",
362 {"_packed_params",
363 "Tensor",
364 "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}},
365 {"__torch__.torch.nn.quantized.dynamic.modules.linear.Linear",
366 {"_packed_params",
367 "Tensor",
368 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}};
369 // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
370 static std::regex mangle_re("\\.___torch_mangle_\\d+");
371 auto demangled_classname =
372 std::regex_replace(qualified_classname.qualifiedName(), mangle_re, "");
373 if (replacements.count(demangled_classname)) {
374 auto lhs = Var(assign.lhs());
375 if (!assign.type().present() || assign.type().get().kind() != TK_VAR) {
376 return std::nullopt;
377 }
378 auto type = Var(assign.type().get());
379
380 auto& attr_name = replacements.at(demangled_classname).attr_name;
381 auto& expected_type = replacements.at(demangled_classname).expected_type;
382 auto& replacement_type =
383 replacements.at(demangled_classname).replacement_type;
384 if (lhs.name().name() == attr_name && type.name().name() == expected_type) {
385 Parser p(std::make_shared<Source>(replacement_type));
386 auto typename_expr = p.parseExp();
387 auto maybe_typename =
388 Maybe<Expr>::create(typename_expr.range(), typename_expr);
389 return Assign::create(
390 assign.range(), assign.lhs_list(), assign.rhs(), maybe_typename);
391 }
392 }
393 return std::nullopt;
394 }
395
importClass(const QualifiedName & qualified_classname,const ClassDef & class_def,bool is_module)396 void SourceImporterImpl::importClass(
397 const QualifiedName& qualified_classname,
398 const ClassDef& class_def,
399 bool is_module) {
400 // BC for TorchBind classes
401 //
402 // Previously we would serialize TorchBind classes as actual
403 // classes with methods that delegate to things in the
404 // torch.ops.* namespace. We've switched away from this and
405 // now just rely on those classes being present in the binary
406 // and emit code for them based on the ClassType in memory.
407 //
408 // TODO: remove this once we no longer have old TorchBind code
409 // in production models
410 {
411 static QualifiedName torch_classes_qualname("__torch__.torch.classes");
412 if (torch_classes_qualname.isPrefixOf(qualified_classname)) {
413 return;
414 }
415 }
416 auto class_type = ClassType::create(
417 c10::QualifiedName(qualified_classname), cu_, is_module);
418
419 std::vector<Def> methods;
420 std::vector<ResolverPtr> method_resolvers;
421 std::map<std::string, Def> pre_hook_def_map;
422 std::map<std::string, Def> hook_def_map;
423 std::map<std::string, ResolverPtr> pre_hook_resolver_map;
424 std::map<std::string, ResolverPtr> hook_resolver_map;
425 std::vector<Assign> attributes;
426 std::vector<Assign> constants;
427
428 // Module-specific: which attrs are parameters?
429 std::unordered_set<std::string> parameter_names;
430 std::unordered_set<std::string> buffer_names;
431 std::unordered_set<std::string> pre_hook_names;
432 std::unordered_set<std::string> hook_names;
433 // used to keep track of original ordering of hooks and prehooks
434 // in case any are called more than once
435 std::vector<std::string> pre_hooks_order;
436 std::vector<std::string> hooks_order;
437 // Process statements, splitting things into attribute and method
438 // definitions.
439 for (const auto& statement : class_def.body()) {
440 switch (statement.kind()) {
441 case TK_ASSIGN: {
442 const auto assign = Assign(statement);
443 auto check_assign_values = [&assign](const std::string& name) {
444 TORCH_CHECK(
445 assign.rhs().present(),
446 "Malformed assignment statement: missing values to assign in ",
447 name);
448 };
449 switch (assign.lhs().kind()) {
450 case TK_VAR: {
451 const auto name = Var(assign.lhs()).name().name();
452 if (name == "__parameters__") {
453 // Populate the module parameter list. This is a field that
454 // looks like:
455 // __parameters__ = ["foo", "bar", "baz"]
456 // which tells us which attributes are module parameters.
457 TORCH_INTERNAL_ASSERT(
458 is_module,
459 "Assignments in class body only "
460 "supported on modules right now");
461 check_assign_values(name);
462 const auto param_list = ListLiteral(assign.rhs().get()).inputs();
463 for (const auto& param : param_list) {
464 parameter_names.insert(StringLiteral(param).text());
465 }
466 } else if (name == "__annotations__") {
467 // This is to initialize the annotations dict, just ignore.
468 continue;
469 } else if (name == "__buffers__") {
470 TORCH_INTERNAL_ASSERT(
471 is_module, "Buffers only exist on modules at the moment");
472 check_assign_values(name);
473 const auto buffer_list = ListLiteral(assign.rhs().get()).inputs();
474 for (const auto& buffer : buffer_list) {
475 buffer_names.insert(StringLiteral(buffer).text());
476 }
477 } else if (name == "__forward_pre_hooks__") {
478 TORCH_INTERNAL_ASSERT(
479 is_module,
480 "Forward pre hooks only exist on modules at the moment");
481 check_assign_values(name);
482 const auto pre_hook_list =
483 ListLiteral(assign.rhs().get()).inputs();
484 for (const auto& pre_hook : pre_hook_list) {
485 std::string pre_hook_name = StringLiteral(pre_hook).text();
486 pre_hook_names.insert(pre_hook_name);
487 pre_hooks_order.emplace_back(pre_hook_name);
488 }
489 } else if (name == "__forward_hooks__") {
490 TORCH_INTERNAL_ASSERT(
491 is_module,
492 "Forward hooks only exist on modules at the moment");
493 check_assign_values(name);
494 const auto hook_list = ListLiteral(assign.rhs().get()).inputs();
495 for (const auto& hook : hook_list) {
496 std::string hook_name = StringLiteral(hook).text();
497 hook_names.insert(hook_name);
498 hooks_order.emplace_back(hook_name);
499 }
500 } else {
501 if (auto fixed_up = attributeAssignmentSpecialHandlingHack(
502 qualified_classname, assign)) {
503 attributes.push_back(std::move(*fixed_up));
504 } else if (assign.rhs().present()) {
505 // This is a constant assignment, of the form:
506 // foo : Final[int] = 3
507 constants.push_back(assign);
508 } else {
509 // This is a regular attribute assignment, of the form:
510 // foo : Tensor
511 attributes.push_back(assign);
512 }
513 }
514 } break;
515 case TK_SUBSCRIPT: {
516 // This is a special attribute assignment where the attribute
517 // is not a valid python, identifier. Looks like:
518 // __annotations__["0"] = Tensor
519 const auto lhs = Subscript(assign.lhs());
520 TORCH_INTERNAL_ASSERT(
521 Var(lhs.value()).name().name() == "__annotations__");
522 TORCH_INTERNAL_ASSERT(lhs.subscript_exprs().size() == 1);
523 attributes.push_back(assign);
524 } break;
525 default: {
526 TORCH_INTERNAL_ASSERT(
527 false,
528 "Unexpected statement kind in module metadata: ",
529 kindToString(statement.kind()));
530 }
531 }
532 } break;
533 case TK_DEF: {
534 Def def = Def(statement);
535 const auto def_name = def.name().name();
536 if (pre_hook_names.find(def_name) != pre_hook_names.end()) {
537 pre_hook_def_map.emplace(def_name, def);
538 pre_hook_resolver_map.emplace(def_name, shared_from_this());
539 } else if (hook_names.find(def_name) != hook_names.end()) {
540 hook_def_map.emplace(def_name, def);
541 hook_resolver_map.emplace(def_name, shared_from_this());
542 } else {
543 methods.emplace_back(def);
544 method_resolvers.push_back(shared_from_this());
545 }
546 } break;
547 default: {
548 TORCH_INTERNAL_ASSERT(
549 false,
550 "Unexpected statement kind in class body: ",
551 kindToString(statement.kind()));
552 }
553 }
554 }
555
556 // Populate class attributes
557 ScriptTypeParser type_parser(shared_from_this());
558 for (const auto& assign : attributes) {
559 // NOLINTNEXTLINE(bugprone-switch-missing-default-case)
560 switch (assign.lhs().kind()) {
561 case TK_VAR: {
562 const auto name = Var(assign.lhs()).name().name();
563 TORCH_INTERNAL_ASSERT(name != "__parameters__");
564 const auto type = assign.type().present()
565 ? type_parser.parseTypeFromExpr(assign.type().get())
566 : type_parser.parseTypeFromExpr(assign.rhs().get());
567 const bool is_parameter = parameter_names.count(name);
568 const bool is_buffer = buffer_names.count(name);
569 class_type->addAttribute(name, type, is_parameter, is_buffer);
570 } break;
571 case TK_SUBSCRIPT: {
572 const auto name =
573 StringLiteral(Subscript(assign.lhs()).subscript_exprs()[0]).text();
574 const auto type = assign.type().present()
575 ? type_parser.parseTypeFromExpr(assign.type().get())
576 : type_parser.parseTypeFromExpr(assign.rhs().get());
577 const bool is_parameter = parameter_names.count(name);
578 const bool is_buffer = buffer_names.count(name);
579 class_type->addAttribute(name, type, is_parameter, is_buffer);
580 }
581 }
582 }
583
584 // Populate class constants
585 for (const auto& assign : constants) {
586 auto const_val = type_parser.parseClassConstant(assign);
587 const auto name = Var(assign.lhs()).name().name();
588 class_type->addConstant(name, const_val);
589 }
590
591 // build pre hook and hook def/resolver pairs
592 // pairs are dedupped in ir_emitter.cpp's CompilationUnit::define_hooks()
593 // ordering here is call order for hooks
594 std::vector<Def> hooks;
595 std::vector<ResolverPtr> hook_resolvers;
596 for (const std::string& hook_name : hooks_order) {
597 hooks.emplace_back(hook_def_map.find(hook_name)->second);
598 hook_resolvers.push_back(hook_resolver_map.find(hook_name)->second);
599 }
600 std::vector<Def> pre_hooks;
601 std::vector<ResolverPtr> pre_hook_resolvers;
602 for (const std::string& pre_hook_name : pre_hooks_order) {
603 pre_hooks.emplace_back(pre_hook_def_map.find(pre_hook_name)->second);
604 pre_hook_resolvers.push_back(
605 pre_hook_resolver_map.find(pre_hook_name)->second);
606 }
607
608 cu_->register_type(class_type);
609 const auto self = SimpleSelf(class_type);
610 // TODO (this will include the version number later)
611 cu_->define(
612 qualified_classname,
613 /*properties=*/{},
614 /*propResolvers=*/{},
615 methods,
616 method_resolvers,
617 &self,
618 /*shouldMangle=*/false,
619 /*operator_set_version=*/version_);
620 cu_->define_hooks(
621 qualified_classname,
622 hooks,
623 hook_resolvers,
624 pre_hooks,
625 pre_hook_resolvers,
626 &self);
627 }
628
importEnum(const QualifiedName & qualified_name,const ClassDef & enum_def)629 void SourceImporterImpl::importEnum(
630 const QualifiedName& qualified_name,
631 const ClassDef& enum_def) {
632 std::vector<at::EnumNameValue> names_values;
633
634 TypePtr value_type = nullptr;
635 auto set_or_check_type =
636 [&value_type](const TypePtr& t, const SourceRange& loc) {
637 if (!value_type) {
638 value_type = t;
639 } else if (value_type != t) {
640 throw(
641 ErrorReport(loc)
642 << "Enum class with varying value types are not supported.");
643 }
644 };
645
646 for (const auto& statement : enum_def.body()) {
647 if (statement.kind() != TK_ASSIGN) {
648 throw(
649 ErrorReport(statement.range())
650 << "Unexpected statement in Enum class body: "
651 "only enum attribute definitions are currently supported.");
652 }
653
654 const auto assign = Assign(statement);
655 const auto name = Var(assign.lhs()).name().name();
656
657 IValue ivalue;
658 auto rhs = assign.rhs().get();
659 switch (rhs.kind()) {
660 case TK_STRINGLITERAL:
661 ivalue = IValue(StringLiteral(rhs).text());
662 set_or_check_type(StringType::get(), statement.range());
663 break;
664 case TK_CONST: {
665 auto numeric_const = Const(rhs);
666 if (numeric_const.isFloatingPoint()) {
667 ivalue = IValue(numeric_const.asFloatingPoint());
668 set_or_check_type(FloatType::get(), statement.range());
669 } else if (numeric_const.isIntegral()) {
670 ivalue = IValue(numeric_const.asIntegral());
671 set_or_check_type(IntType::get(), statement.range());
672 }
673 break;
674 }
675 default:
676 throw(
677 ErrorReport(rhs.range())
678 << "Unsupported enum value type: " << rhs.kind()
679 << ". Only Integers, Floats and Strings are supported.");
680 }
681
682 names_values.emplace_back(name, ivalue);
683 }
684
685 if (!value_type) {
686 throw(
687 ErrorReport(enum_def.range())
688 << "No enum values defined for " << qualified_name.qualifiedName());
689 }
690
691 auto enum_type = EnumType::create(
692 qualified_name, std::move(value_type), std::move(names_values), cu_);
693 cu_->register_type(enum_type);
694 }
695
importNamedTuple(const QualifiedName & qualified_name,const ClassDef & named_tuple_def)696 void SourceImporterImpl::importNamedTuple(
697 const QualifiedName& qualified_name,
698 const ClassDef& named_tuple_def) {
699 ScriptTypeParser type_parser(shared_from_this());
700 std::vector<std::string> field_names;
701 std::vector<TypePtr> field_types;
702 std::vector<IValue> field_defaults;
703 for (const auto& statement : named_tuple_def.body()) {
704 if (statement.kind() != TK_ASSIGN) {
705 throw(
706 ErrorReport(statement.range())
707 << "Unexpected statement in NamedTuple body: "
708 "only attribute annotations are currently supported.");
709 }
710 const auto assign = Assign(statement);
711 TORCH_INTERNAL_ASSERT(assign.type().present());
712
713 auto name = Var(Assign(statement).lhs()).name().name();
714 std::optional<IValue> default_val;
715 if (assign.rhs().present()) {
716 std::vector<IValue> parsed = type_parser.evaluateDefaults(
717 assign.rhs().range(), {assign.rhs().get()}, {assign.type().get()});
718 TORCH_INTERNAL_ASSERT(parsed.size() == 1);
719 default_val = parsed[0];
720 }
721
722 auto type = type_parser.parseTypeFromExpr(assign.type().get());
723
724 field_names.emplace_back(std::move(name));
725 field_types.emplace_back(std::move(type));
726 if (default_val) {
727 field_defaults.emplace_back(std::move(*default_val));
728 }
729 }
730
731 auto tt = TupleType::createNamed(
732 qualified_name, field_names, field_types, field_defaults);
733 cu_->register_type(tt);
734 }
735
parsePossibleVersionNumber(Lexer & L)736 void SourceImporterImpl::parsePossibleVersionNumber(Lexer& L) {
737 // Older versions of serialization produced an op_version_set string
738 // per-file We now just use a single version which is handled by
739 // PyTorchStreamReader. We used to check if op_version_set was _newer_ for
740 // forward compatibility reasons but now that it doesn't exist there can't
741 // be a newer one, so we just discard this.
742 if (L.cur().kind == TK_IDENT && L.cur().text() == "op_version_set") {
743 auto range = L.cur().range;
744 L.next();
745 L.expect('=');
746 L.expect(TK_NUMBER);
747 L.expect(TK_NEWLINE);
748 }
749 }
750
751 // older versions of serialization required import statements,
752 // and defined classes file-at-a-time in import order.
753 // The problem is that in Python
754 // it is possible to construct cyclic dependencies between files even
755 // when there are none between individual classes. New versions of loading
756 // just compile class-at-a-time, so we no longer need to follow the import
757 // order. Future serialization may stop producing the import code.
parseImports(Lexer & L)758 void SourceImporterImpl::parseImports(Lexer& L) {
759 while (L.nextIf(TK_IMPORT)) {
760 std::ostringstream s;
761 while (L.cur().kind != TK_NEWLINE) {
762 s << L.cur().text();
763 L.next();
764 }
765 L.expect(TK_NEWLINE);
766 }
767 }
768
attr(const SourceRange & loc,GraphFunction & m,const std::string & name)769 std::shared_ptr<SugaredValue> ClassNamespaceValue::attr(
770 const SourceRange& loc,
771 GraphFunction& m,
772 const std::string& name) {
773 auto fullName = c10::QualifiedName(basename_, name);
774 // Could be a ClassType or NamedTuple constructor
775 if (auto serializable_type = si_->findNamedType(fullName)) {
776 if (auto classType = serializable_type->cast<ClassType>()) {
777 return std::make_shared<ClassValue>(classType);
778 } else if (auto tupleType = serializable_type->cast<TupleType>()) {
779 return std::make_shared<NamedTupleConstructor>(tupleType);
780 } else if (auto enumType = serializable_type->cast<EnumType>()) {
781 return std::make_shared<SugaredEnumClass>(enumType);
782 }
783 }
784
785 // Or it could be a free function
786 if (auto fn = si_->findFunction(fullName)) {
787 return std::make_shared<FunctionValue>(fn);
788 }
789
790 // If it's none of those things, assume it's another namespace
791 return std::make_shared<ClassNamespaceValue>(std::move(fullName), si_);
792 }
793
SourceImporter(std::shared_ptr<CompilationUnit> cu,const std::vector<IValue> * constant_table,SourceLoader loader,size_t version)794 SourceImporter::SourceImporter(
795 // The compilation unit that will own the imported source
796 std::shared_ptr<CompilationUnit> cu,
797 const std::vector<IValue>* constant_table,
798 SourceLoader loader,
799 size_t version)
800 : pImpl(std::make_shared<SourceImporterImpl>(
801 std::move(cu),
802 constant_table,
803 std::move(loader),
804 version)) {}
805
loadType(const QualifiedName & name) const806 TypePtr SourceImporter::loadType(const QualifiedName& name) const {
807 ScriptTypeParser type_parser(pImpl);
808 TypePtr t = type_parser.parseType(name.qualifiedName());
809 return t;
810 }
811
LEGACY_import_methods(const Module & mod,const std::shared_ptr<Source> & src)812 void SourceImporter::LEGACY_import_methods(
813 const Module& mod,
814 const std::shared_ptr<Source>& src) {
815 pImpl->LEGACY_import_methods(mod, src);
816 }
817 SourceImporter::~SourceImporter() = default;
818
819 } // namespace torch::jit
820