xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/import_source.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/serialization/import_source.h>
2 
3 #include <ATen/core/ivalue_inl.h>
4 #include <ATen/core/qualified_name.h>
5 #include <torch/csrc/jit/frontend/parser.h>
6 #include <torch/csrc/jit/frontend/resolver.h>
7 #include <torch/csrc/jit/frontend/script_type_parser.h>
8 #include <torch/custom_class.h>
9 
10 #include <regex>
11 
12 namespace torch::jit {
13 
14 struct OpsValue : public SugaredValue {
OpsValuetorch::jit::OpsValue15   OpsValue(size_t version) : version_(version) {}
kindtorch::jit::OpsValue16   std::string kind() const override {
17     return "ops";
18   }
attrtorch::jit::OpsValue19   std::shared_ptr<SugaredValue> attr(
20       const SourceRange& loc,
21       GraphFunction& m,
22       const std::string& field) override {
23     return std::make_shared<BuiltinModule>(field, version_);
24   }
25   size_t version_;
26 };
27 
28 // Represents nested namespaces, like `foo.bar.Baz`.
29 // Right now these namespaces can only contain other namespaces or NamedTypes
30 struct TORCH_API ClassNamespaceValue : public SugaredValue {
31   /**
32    * @param  name  The fully qualified path, which can resolve either to a
33    *               namespace or a NamedType
34    * @param  si    The source importer that searches for and loads
35    * classes/functions.
36    */
ClassNamespaceValuetorch::jit::ClassNamespaceValue37   explicit ClassNamespaceValue(
38       c10::QualifiedName name,
39       std::shared_ptr<SourceImporterImpl> si)
40       : basename_(std::move(name)), si_(std::move(si)) {}
41 
42   std::shared_ptr<SugaredValue> attr(
43       const SourceRange& loc,
44       GraphFunction& m,
45       const std::string& name) override;
kindtorch::jit::ClassNamespaceValue46   std::string kind() const override {
47     return "Class Namespace";
48   }
49 
50  private:
51   c10::QualifiedName basename_;
52   std::shared_ptr<SourceImporterImpl> si_;
53 };
54 
55 // This value maps attributes CONSTANTS.c0 CONSTANTS.c1 to entries
56 // in the 'constants' vector. This table is will be stored in a container format
57 // and given to the import_method when restoring the code.
58 struct ConstantTableValue : public SugaredValue {
ConstantTableValuetorch::jit::ConstantTableValue59   explicit ConstantTableValue(const std::vector<at::IValue>* constants)
60       : constants_(constants) {}
kindtorch::jit::ConstantTableValue61   std::string kind() const override {
62     return "CONSTANTS";
63   }
64   // select an attribute on it, e.g. `this.field`
attrtorch::jit::ConstantTableValue65   std::shared_ptr<SugaredValue> attr(
66       const SourceRange& loc,
67       GraphFunction& m,
68       const std::string& field) override {
69     const char* field_s = field.c_str();
70     char* end = nullptr;
71     int64_t offset = strtoll(field_s + 1, &end, 10);
72     if (field.size() < 2 || *end != 0)
73       throw(ErrorReport(loc) << "invalid constant specifier: " << field);
74     if (offset < 0 || size_t(offset) >= constants_->size()) {
75       throw(
76           ErrorReport(loc) << "constant index " << offset
77                            << " is out of bounds (constant table has "
78                            << constants_->size() << " entries)");
79     }
80     auto ivalue = constants_->at(offset);
81     Value* value = nullptr;
82 
83     // see [Constant Object Weak CompilationUnit Reference]
84     if (ivalue.isObject() && !ivalue.toObject()->is_weak_compilation_ref()) {
85       auto obj = ivalue.toObject();
86       if (!non_holding_object_cache.count(obj)) {
87         non_holding_object_cache[obj] = obj->copy_to_weak_compilation_ref();
88       }
89       value = m.graph()->insertConstant(non_holding_object_cache[obj], loc);
90     } else {
91       value = m.graph()->insertConstant(constants_->at(offset), loc);
92     }
93 
94     // specializing tensor type on compilation messes up typing relations
95     value->setType(unshapedType(value->type()));
96 
97     return std::make_shared<SimpleValue>(value);
98   }
99 
100  private:
101   std::unordered_map<
102       c10::intrusive_ptr<at::ivalue::Object>,
103       c10::intrusive_ptr<at::ivalue::Object>>
104       non_holding_object_cache;
105   const std::vector<at::IValue>* constants_;
106 };
107 
SourceImporterImpl(std::shared_ptr<CompilationUnit> cu,const std::vector<at::IValue> * constant_table,SourceLoader source_loader,size_t version)108 SourceImporterImpl::SourceImporterImpl(
109     std::shared_ptr<CompilationUnit> cu,
110     const std::vector<at::IValue>* constant_table,
111     SourceLoader source_loader,
112     size_t version)
113     : cu_(std::move(cu)),
114       source_loader_(std::move(source_loader)),
115       version_(version) {
116   env_ = {
117       {"torch", std::make_shared<BuiltinModule>("aten", version)},
118       {"ops", std::make_shared<OpsValue>(version)},
119       // Constants present in the model. Used to resolve "CONSTANTS.n" to the
120       // actual value
121       {"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
122       {"fork", SpecialFormValue::create(prim::fork)},
123       {"awaitable", SpecialFormValue::create(prim::awaitable)},
124       {"annotate", SpecialFormValue::create(prim::annotate)},
125       {"unchecked_cast", SpecialFormValue::create(prim::unchecked_cast)},
126       {"uninitialized", SpecialFormValue::create(prim::Uninitialized)},
127   };
128 }
129 
findNamedType(const QualifiedName & name)130 TypePtr SourceImporterImpl::findNamedType(const QualifiedName& name) {
131   if (auto custom_class = getCustomClass(name.qualifiedName())) {
132     return custom_class;
133   }
134   parseSourceIfNeeded(name.prefix());
135   auto it = to_be_defined_.find(name);
136   if (it != to_be_defined_.end() && it->second->kind() == TK_CLASS_DEF) {
137     ClassDef cd(std::move(it->second));
138     to_be_defined_.erase(it);
139     importNamedType(name.prefix(), cd);
140   }
141   return cu_->get_type(name);
142 }
143 
findFunction(const QualifiedName & name)144 Function* SourceImporterImpl::findFunction(const QualifiedName& name) {
145   parseSourceIfNeeded(name.prefix());
146   auto it = to_be_defined_.find(name);
147   if (it != to_be_defined_.end() && it->second->kind() == TK_DEF) {
148     Def d(it->second);
149     to_be_defined_.erase(it);
150     importFunction(name.prefix(), d);
151   }
152   return cu_->find_function(name);
153 }
154 
parseSourceIfNeeded(const std::string & qualifier)155 void SourceImporterImpl::parseSourceIfNeeded(const std::string& qualifier) {
156   // qualifier may be blank, for instance checking if __torch__ is a class.
157   if (qualifier.empty() || loaded_sources_.count(qualifier)) {
158     return;
159   }
160   loaded_sources_.insert(qualifier);
161   std::shared_ptr<Source> src = source_loader_(qualifier);
162 
163   // The importer, when looking for classes/functions doesn't know if 'foo'
164   // contains definitions or if it is a prefix of 'foo.bar', we only figure it
165   // out by testing if `foo.py` exists in the source loader. If it doesn't
166   // then there is nothing to load here
167   if (!src) {
168     return;
169   }
170   Parser p(src);
171   parsePossibleVersionNumber(p.lexer());
172 
173   auto& L = p.lexer();
174 
175   while (L.cur().kind != TK_EOF) {
176     parseImports(L);
177     auto tk = L.cur();
178     auto kind = tk.kind;
179     switch (kind) {
180       case TK_CLASS_DEF: {
181         auto parsed_treeref = ClassDef(p.parseClass());
182         to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] =
183             parsed_treeref;
184       } break;
185       case TK_DEF: {
186         auto parsed_treeref = Def(p.parseFunction(/*is_method=*/false));
187         to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] =
188             parsed_treeref;
189       } break;
190       default:
191         throw(
192             ErrorReport(L.cur().range)
193             << "Unexpected token in code import: " << kindToString(kind));
194     }
195   }
196 }
197 
LEGACY_import_methods(const Module & mod,const std::shared_ptr<Source> & src)198 void SourceImporterImpl::LEGACY_import_methods(
199     const Module& mod,
200     const std::shared_ptr<Source>& src) {
201   auto self = SimpleSelf(mod.type());
202   c10::QualifiedName prefix = *mod.type()->name();
203   Parser p(src);
204 
205   parsePossibleVersionNumber(p.lexer());
206 
207   parseImports(p.lexer());
208 
209   std::vector<Def> definitions;
210   std::vector<ResolverPtr> resolvers;
211   while (p.lexer().cur().kind != TK_EOF) {
212     auto def = Def(p.parseFunction(/*is_method=*/true));
213     definitions.emplace_back(def);
214     resolvers.emplace_back(shared_from_this());
215   }
216   cu_->define(
217       prefix,
218       /*properties=*/{},
219       /*propResolvers=*/{},
220       definitions,
221       resolvers,
222       &self);
223 }
224 
resolveValue(const std::string & name,GraphFunction & m,const SourceRange & loc)225 std::shared_ptr<SugaredValue> SourceImporterImpl::resolveValue(
226     const std::string& name,
227     GraphFunction& m,
228     const SourceRange& loc) {
229   auto it = env_.find(name);
230   if (it != env_.end()) {
231     return it->second;
232   }
233   auto graph = m.graph();
234   if (name == "inf") {
235     return std::make_shared<SimpleValue>(
236         graph->insertConstant(std::numeric_limits<double>::infinity(), loc));
237   }
238   if (name == "nan") {
239     return std::make_shared<SimpleValue>(
240         graph->insertConstant(std::numeric_limits<double>::quiet_NaN(), loc));
241   }
242   if (name == "infj") {
243     return std::make_shared<SimpleValue>(graph->insertConstant(
244         c10::complex<double>(0, std::numeric_limits<double>::infinity()), loc));
245   }
246   if (name == "nanj") {
247     return std::make_shared<SimpleValue>(graph->insertConstant(
248         c10::complex<double>(0, std::numeric_limits<double>::quiet_NaN()),
249         loc));
250   }
251   if (name == "__torch__") {
252     return std::make_shared<ClassNamespaceValue>(
253         c10::QualifiedName(name), shared_from_this());
254   }
255   return nullptr;
256 }
257 
resolveType(const std::string & name,const SourceRange & loc)258 TypePtr SourceImporterImpl::resolveType(
259     const std::string& name,
260     const SourceRange& loc) {
261   return findNamedType(QualifiedName(name));
262 }
263 
importFunction(const std::string & qualifier,const Def & def)264 void SourceImporterImpl::importFunction(
265     const std::string& qualifier,
266     const Def& def) {
267   std::vector<Def> definitions{def};
268   std::vector<ResolverPtr> resolvers{shared_from_this()};
269   cu_->define(
270       qualifier,
271       /*properties=*/{},
272       /*propResolvers=*/{},
273       definitions,
274       resolvers,
275       nullptr);
276 }
277 
importNamedType(const std::string & qualifier,const ClassDef & class_def)278 void SourceImporterImpl::importNamedType(
279     const std::string& qualifier,
280     const ClassDef& class_def) {
281   const auto qualified_name =
282       QualifiedName(QualifiedName(qualifier), class_def.name().name());
283   if (!class_def.superclass().present()) {
284     return importClass(qualified_name, class_def, /*is_module=*/false);
285   }
286   const auto& superclass_name = Var(class_def.superclass().get()).name().name();
287   if (superclass_name == "Module") {
288     importClass(qualified_name, class_def, /*is_module=*/true);
289   } else if (superclass_name == "NamedTuple") {
290     // NamedTuples have special rules (since they are TupleTypes and not
291     // ClassTypes)
292     return importNamedTuple(qualified_name, class_def);
293   } else if (superclass_name == "Interface") {
294     cu_->define_interface(
295         qualified_name, class_def, shared_from_this(), /*is_module=*/false);
296   } else if (superclass_name == "ModuleInterface") {
297     cu_->define_interface(
298         qualified_name, class_def, shared_from_this(), /*is_module=*/true);
299   } else if (superclass_name == "Enum") {
300     importEnum(qualified_name, class_def);
301   } else {
302     throw(
303         ErrorReport(class_def.range())
304         << "Torchscript does not support class inheritance.");
305   }
306 }
307 
308 std::optional<Assign> SourceImporterImpl::
attributeAssignmentSpecialHandlingHack(const QualifiedName & qualified_classname,const Assign & assign)309     attributeAssignmentSpecialHandlingHack(
310         const QualifiedName& qualified_classname,
311         const Assign& assign) {
312   struct AttrTypeReplacementDescr {
313     std::string attr_name;
314     std::string expected_type;
315     std::string replacement_type;
316   };
317 
318   // module demangled qualname -> ReplacementDescr
319   static std::unordered_map<std::string, AttrTypeReplacementDescr> replacements{
320       {"__torch__.torch.ao.nn.quantized.modules.linear.LinearPackedParams",
321        {"_packed_params",
322         "Tensor",
323         "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
324       {"__torch__.torch.ao.nn.quantized.modules.linear.Linear",
325        {"_packed_params",
326         "Tensor",
327         "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
328       {"__torch__.torch.ao.nn.quantized.dynamic.modules.linear.Linear",
329        {"_packed_params",
330         "Tensor",
331         "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
332       {"__torch__.torch.ao.nn.quantized.modules.conv.Conv2d",
333        {"_packed_params",
334         "Tensor",
335         "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
336       {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d",
337        {"_packed_params",
338         "Tensor",
339         "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
340       {"__torch__.torch.ao.nn.quantized.modules.conv.Conv3d",
341        {"_packed_params",
342         "Tensor",
343         "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}},
344       {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d",
345        {"_packed_params",
346         "Tensor",
347         "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}},
348       // BC Stuff
349       {"__torch__.torch.nn.quantized.modules.linear.LinearPackedParams",
350        {"_packed_params",
351         "Tensor",
352         "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
353       {"__torch__.torch.nn.quantized.modules.linear.Linear",
354        {"_packed_params",
355         "Tensor",
356         "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
357       {"__torch__.torch.nn.quantized.modules.conv.Conv2d",
358        {"_packed_params",
359         "Tensor",
360         "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
361       {"__torch__.torch.nn.quantized.modules.conv.Conv3d",
362        {"_packed_params",
363         "Tensor",
364         "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}},
365       {"__torch__.torch.nn.quantized.dynamic.modules.linear.Linear",
366        {"_packed_params",
367         "Tensor",
368         "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}};
369   // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
370   static std::regex mangle_re("\\.___torch_mangle_\\d+");
371   auto demangled_classname =
372       std::regex_replace(qualified_classname.qualifiedName(), mangle_re, "");
373   if (replacements.count(demangled_classname)) {
374     auto lhs = Var(assign.lhs());
375     if (!assign.type().present() || assign.type().get().kind() != TK_VAR) {
376       return std::nullopt;
377     }
378     auto type = Var(assign.type().get());
379 
380     auto& attr_name = replacements.at(demangled_classname).attr_name;
381     auto& expected_type = replacements.at(demangled_classname).expected_type;
382     auto& replacement_type =
383         replacements.at(demangled_classname).replacement_type;
384     if (lhs.name().name() == attr_name && type.name().name() == expected_type) {
385       Parser p(std::make_shared<Source>(replacement_type));
386       auto typename_expr = p.parseExp();
387       auto maybe_typename =
388           Maybe<Expr>::create(typename_expr.range(), typename_expr);
389       return Assign::create(
390           assign.range(), assign.lhs_list(), assign.rhs(), maybe_typename);
391     }
392   }
393   return std::nullopt;
394 }
395 
importClass(const QualifiedName & qualified_classname,const ClassDef & class_def,bool is_module)396 void SourceImporterImpl::importClass(
397     const QualifiedName& qualified_classname,
398     const ClassDef& class_def,
399     bool is_module) {
400   // BC for TorchBind classes
401   //
402   // Previously we would serialize TorchBind classes as actual
403   // classes with methods that delegate to things in the
404   // torch.ops.* namespace. We've switched away from this and
405   // now just rely on those classes being present in the binary
406   // and emit code for them based on the ClassType in memory.
407   //
408   // TODO: remove this once we no longer have old TorchBind code
409   // in production models
410   {
411     static QualifiedName torch_classes_qualname("__torch__.torch.classes");
412     if (torch_classes_qualname.isPrefixOf(qualified_classname)) {
413       return;
414     }
415   }
416   auto class_type = ClassType::create(
417       c10::QualifiedName(qualified_classname), cu_, is_module);
418 
419   std::vector<Def> methods;
420   std::vector<ResolverPtr> method_resolvers;
421   std::map<std::string, Def> pre_hook_def_map;
422   std::map<std::string, Def> hook_def_map;
423   std::map<std::string, ResolverPtr> pre_hook_resolver_map;
424   std::map<std::string, ResolverPtr> hook_resolver_map;
425   std::vector<Assign> attributes;
426   std::vector<Assign> constants;
427 
428   // Module-specific: which attrs are parameters?
429   std::unordered_set<std::string> parameter_names;
430   std::unordered_set<std::string> buffer_names;
431   std::unordered_set<std::string> pre_hook_names;
432   std::unordered_set<std::string> hook_names;
433   // used to keep track of original ordering of hooks and prehooks
434   // in case any are called more than once
435   std::vector<std::string> pre_hooks_order;
436   std::vector<std::string> hooks_order;
437   // Process statements, splitting things into attribute and method
438   // definitions.
439   for (const auto& statement : class_def.body()) {
440     switch (statement.kind()) {
441       case TK_ASSIGN: {
442         const auto assign = Assign(statement);
443         auto check_assign_values = [&assign](const std::string& name) {
444           TORCH_CHECK(
445               assign.rhs().present(),
446               "Malformed assignment statement: missing values to assign in ",
447               name);
448         };
449         switch (assign.lhs().kind()) {
450           case TK_VAR: {
451             const auto name = Var(assign.lhs()).name().name();
452             if (name == "__parameters__") {
453               // Populate the module parameter list. This is a field that
454               // looks like:
455               //   __parameters__ = ["foo", "bar", "baz"]
456               // which tells us which attributes are module parameters.
457               TORCH_INTERNAL_ASSERT(
458                   is_module,
459                   "Assignments in class body only "
460                   "supported on modules right now");
461               check_assign_values(name);
462               const auto param_list = ListLiteral(assign.rhs().get()).inputs();
463               for (const auto& param : param_list) {
464                 parameter_names.insert(StringLiteral(param).text());
465               }
466             } else if (name == "__annotations__") {
467               // This is to initialize the annotations dict, just ignore.
468               continue;
469             } else if (name == "__buffers__") {
470               TORCH_INTERNAL_ASSERT(
471                   is_module, "Buffers only exist on modules at the moment");
472               check_assign_values(name);
473               const auto buffer_list = ListLiteral(assign.rhs().get()).inputs();
474               for (const auto& buffer : buffer_list) {
475                 buffer_names.insert(StringLiteral(buffer).text());
476               }
477             } else if (name == "__forward_pre_hooks__") {
478               TORCH_INTERNAL_ASSERT(
479                   is_module,
480                   "Forward pre hooks only exist on modules at the moment");
481               check_assign_values(name);
482               const auto pre_hook_list =
483                   ListLiteral(assign.rhs().get()).inputs();
484               for (const auto& pre_hook : pre_hook_list) {
485                 std::string pre_hook_name = StringLiteral(pre_hook).text();
486                 pre_hook_names.insert(pre_hook_name);
487                 pre_hooks_order.emplace_back(pre_hook_name);
488               }
489             } else if (name == "__forward_hooks__") {
490               TORCH_INTERNAL_ASSERT(
491                   is_module,
492                   "Forward hooks only exist on modules at the moment");
493               check_assign_values(name);
494               const auto hook_list = ListLiteral(assign.rhs().get()).inputs();
495               for (const auto& hook : hook_list) {
496                 std::string hook_name = StringLiteral(hook).text();
497                 hook_names.insert(hook_name);
498                 hooks_order.emplace_back(hook_name);
499               }
500             } else {
501               if (auto fixed_up = attributeAssignmentSpecialHandlingHack(
502                       qualified_classname, assign)) {
503                 attributes.push_back(std::move(*fixed_up));
504               } else if (assign.rhs().present()) {
505                 // This is a constant assignment, of the form:
506                 // foo : Final[int] = 3
507                 constants.push_back(assign);
508               } else {
509                 // This is a regular attribute assignment, of the form:
510                 // foo : Tensor
511                 attributes.push_back(assign);
512               }
513             }
514           } break;
515           case TK_SUBSCRIPT: {
516             // This is a special attribute assignment where the attribute
517             // is not a valid python, identifier. Looks like:
518             //    __annotations__["0"] = Tensor
519             const auto lhs = Subscript(assign.lhs());
520             TORCH_INTERNAL_ASSERT(
521                 Var(lhs.value()).name().name() == "__annotations__");
522             TORCH_INTERNAL_ASSERT(lhs.subscript_exprs().size() == 1);
523             attributes.push_back(assign);
524           } break;
525           default: {
526             TORCH_INTERNAL_ASSERT(
527                 false,
528                 "Unexpected statement kind in module metadata: ",
529                 kindToString(statement.kind()));
530           }
531         }
532       } break;
533       case TK_DEF: {
534         Def def = Def(statement);
535         const auto def_name = def.name().name();
536         if (pre_hook_names.find(def_name) != pre_hook_names.end()) {
537           pre_hook_def_map.emplace(def_name, def);
538           pre_hook_resolver_map.emplace(def_name, shared_from_this());
539         } else if (hook_names.find(def_name) != hook_names.end()) {
540           hook_def_map.emplace(def_name, def);
541           hook_resolver_map.emplace(def_name, shared_from_this());
542         } else {
543           methods.emplace_back(def);
544           method_resolvers.push_back(shared_from_this());
545         }
546       } break;
547       default: {
548         TORCH_INTERNAL_ASSERT(
549             false,
550             "Unexpected statement kind in class body: ",
551             kindToString(statement.kind()));
552       }
553     }
554   }
555 
556   // Populate class attributes
557   ScriptTypeParser type_parser(shared_from_this());
558   for (const auto& assign : attributes) {
559     // NOLINTNEXTLINE(bugprone-switch-missing-default-case)
560     switch (assign.lhs().kind()) {
561       case TK_VAR: {
562         const auto name = Var(assign.lhs()).name().name();
563         TORCH_INTERNAL_ASSERT(name != "__parameters__");
564         const auto type = assign.type().present()
565             ? type_parser.parseTypeFromExpr(assign.type().get())
566             : type_parser.parseTypeFromExpr(assign.rhs().get());
567         const bool is_parameter = parameter_names.count(name);
568         const bool is_buffer = buffer_names.count(name);
569         class_type->addAttribute(name, type, is_parameter, is_buffer);
570       } break;
571       case TK_SUBSCRIPT: {
572         const auto name =
573             StringLiteral(Subscript(assign.lhs()).subscript_exprs()[0]).text();
574         const auto type = assign.type().present()
575             ? type_parser.parseTypeFromExpr(assign.type().get())
576             : type_parser.parseTypeFromExpr(assign.rhs().get());
577         const bool is_parameter = parameter_names.count(name);
578         const bool is_buffer = buffer_names.count(name);
579         class_type->addAttribute(name, type, is_parameter, is_buffer);
580       }
581     }
582   }
583 
584   // Populate class constants
585   for (const auto& assign : constants) {
586     auto const_val = type_parser.parseClassConstant(assign);
587     const auto name = Var(assign.lhs()).name().name();
588     class_type->addConstant(name, const_val);
589   }
590 
591   // build pre hook and hook def/resolver pairs
592   // pairs are dedupped in ir_emitter.cpp's CompilationUnit::define_hooks()
593   // ordering here is call order for hooks
594   std::vector<Def> hooks;
595   std::vector<ResolverPtr> hook_resolvers;
596   for (const std::string& hook_name : hooks_order) {
597     hooks.emplace_back(hook_def_map.find(hook_name)->second);
598     hook_resolvers.push_back(hook_resolver_map.find(hook_name)->second);
599   }
600   std::vector<Def> pre_hooks;
601   std::vector<ResolverPtr> pre_hook_resolvers;
602   for (const std::string& pre_hook_name : pre_hooks_order) {
603     pre_hooks.emplace_back(pre_hook_def_map.find(pre_hook_name)->second);
604     pre_hook_resolvers.push_back(
605         pre_hook_resolver_map.find(pre_hook_name)->second);
606   }
607 
608   cu_->register_type(class_type);
609   const auto self = SimpleSelf(class_type);
610   // TODO (this will include the version number later)
611   cu_->define(
612       qualified_classname,
613       /*properties=*/{},
614       /*propResolvers=*/{},
615       methods,
616       method_resolvers,
617       &self,
618       /*shouldMangle=*/false,
619       /*operator_set_version=*/version_);
620   cu_->define_hooks(
621       qualified_classname,
622       hooks,
623       hook_resolvers,
624       pre_hooks,
625       pre_hook_resolvers,
626       &self);
627 }
628 
importEnum(const QualifiedName & qualified_name,const ClassDef & enum_def)629 void SourceImporterImpl::importEnum(
630     const QualifiedName& qualified_name,
631     const ClassDef& enum_def) {
632   std::vector<at::EnumNameValue> names_values;
633 
634   TypePtr value_type = nullptr;
635   auto set_or_check_type =
636       [&value_type](const TypePtr& t, const SourceRange& loc) {
637         if (!value_type) {
638           value_type = t;
639         } else if (value_type != t) {
640           throw(
641               ErrorReport(loc)
642               << "Enum class with varying value types are not supported.");
643         }
644       };
645 
646   for (const auto& statement : enum_def.body()) {
647     if (statement.kind() != TK_ASSIGN) {
648       throw(
649           ErrorReport(statement.range())
650           << "Unexpected statement in Enum class body: "
651              "only enum attribute definitions are currently supported.");
652     }
653 
654     const auto assign = Assign(statement);
655     const auto name = Var(assign.lhs()).name().name();
656 
657     IValue ivalue;
658     auto rhs = assign.rhs().get();
659     switch (rhs.kind()) {
660       case TK_STRINGLITERAL:
661         ivalue = IValue(StringLiteral(rhs).text());
662         set_or_check_type(StringType::get(), statement.range());
663         break;
664       case TK_CONST: {
665         auto numeric_const = Const(rhs);
666         if (numeric_const.isFloatingPoint()) {
667           ivalue = IValue(numeric_const.asFloatingPoint());
668           set_or_check_type(FloatType::get(), statement.range());
669         } else if (numeric_const.isIntegral()) {
670           ivalue = IValue(numeric_const.asIntegral());
671           set_or_check_type(IntType::get(), statement.range());
672         }
673         break;
674       }
675       default:
676         throw(
677             ErrorReport(rhs.range())
678             << "Unsupported enum value type: " << rhs.kind()
679             << ". Only Integers, Floats and Strings are supported.");
680     }
681 
682     names_values.emplace_back(name, ivalue);
683   }
684 
685   if (!value_type) {
686     throw(
687         ErrorReport(enum_def.range())
688         << "No enum values defined for " << qualified_name.qualifiedName());
689   }
690 
691   auto enum_type = EnumType::create(
692       qualified_name, std::move(value_type), std::move(names_values), cu_);
693   cu_->register_type(enum_type);
694 }
695 
importNamedTuple(const QualifiedName & qualified_name,const ClassDef & named_tuple_def)696 void SourceImporterImpl::importNamedTuple(
697     const QualifiedName& qualified_name,
698     const ClassDef& named_tuple_def) {
699   ScriptTypeParser type_parser(shared_from_this());
700   std::vector<std::string> field_names;
701   std::vector<TypePtr> field_types;
702   std::vector<IValue> field_defaults;
703   for (const auto& statement : named_tuple_def.body()) {
704     if (statement.kind() != TK_ASSIGN) {
705       throw(
706           ErrorReport(statement.range())
707           << "Unexpected statement in NamedTuple body: "
708              "only attribute annotations are currently supported.");
709     }
710     const auto assign = Assign(statement);
711     TORCH_INTERNAL_ASSERT(assign.type().present());
712 
713     auto name = Var(Assign(statement).lhs()).name().name();
714     std::optional<IValue> default_val;
715     if (assign.rhs().present()) {
716       std::vector<IValue> parsed = type_parser.evaluateDefaults(
717           assign.rhs().range(), {assign.rhs().get()}, {assign.type().get()});
718       TORCH_INTERNAL_ASSERT(parsed.size() == 1);
719       default_val = parsed[0];
720     }
721 
722     auto type = type_parser.parseTypeFromExpr(assign.type().get());
723 
724     field_names.emplace_back(std::move(name));
725     field_types.emplace_back(std::move(type));
726     if (default_val) {
727       field_defaults.emplace_back(std::move(*default_val));
728     }
729   }
730 
731   auto tt = TupleType::createNamed(
732       qualified_name, field_names, field_types, field_defaults);
733   cu_->register_type(tt);
734 }
735 
parsePossibleVersionNumber(Lexer & L)736 void SourceImporterImpl::parsePossibleVersionNumber(Lexer& L) {
737   // Older versions of serialization produced an op_version_set string
738   // per-file We now just use a single version which is handled by
739   // PyTorchStreamReader. We used to check if op_version_set was _newer_ for
740   // forward compatibility reasons but now that it doesn't exist there can't
741   // be a newer one, so we just discard this.
742   if (L.cur().kind == TK_IDENT && L.cur().text() == "op_version_set") {
743     auto range = L.cur().range;
744     L.next();
745     L.expect('=');
746     L.expect(TK_NUMBER);
747     L.expect(TK_NEWLINE);
748   }
749 }
750 
751 // older versions of serialization required import statements,
752 // and defined classes file-at-a-time in import order.
753 // The problem is that in Python
754 // it is possible to construct cyclic dependencies between files even
755 // when there are none between individual classes. New versions of loading
756 // just compile class-at-a-time, so we no longer need to follow the import
757 // order. Future serialization may stop producing the import code.
parseImports(Lexer & L)758 void SourceImporterImpl::parseImports(Lexer& L) {
759   while (L.nextIf(TK_IMPORT)) {
760     std::ostringstream s;
761     while (L.cur().kind != TK_NEWLINE) {
762       s << L.cur().text();
763       L.next();
764     }
765     L.expect(TK_NEWLINE);
766   }
767 }
768 
attr(const SourceRange & loc,GraphFunction & m,const std::string & name)769 std::shared_ptr<SugaredValue> ClassNamespaceValue::attr(
770     const SourceRange& loc,
771     GraphFunction& m,
772     const std::string& name) {
773   auto fullName = c10::QualifiedName(basename_, name);
774   // Could be a ClassType or NamedTuple constructor
775   if (auto serializable_type = si_->findNamedType(fullName)) {
776     if (auto classType = serializable_type->cast<ClassType>()) {
777       return std::make_shared<ClassValue>(classType);
778     } else if (auto tupleType = serializable_type->cast<TupleType>()) {
779       return std::make_shared<NamedTupleConstructor>(tupleType);
780     } else if (auto enumType = serializable_type->cast<EnumType>()) {
781       return std::make_shared<SugaredEnumClass>(enumType);
782     }
783   }
784 
785   // Or it could be a free function
786   if (auto fn = si_->findFunction(fullName)) {
787     return std::make_shared<FunctionValue>(fn);
788   }
789 
790   // If it's none of those things, assume it's another namespace
791   return std::make_shared<ClassNamespaceValue>(std::move(fullName), si_);
792 }
793 
SourceImporter(std::shared_ptr<CompilationUnit> cu,const std::vector<IValue> * constant_table,SourceLoader loader,size_t version)794 SourceImporter::SourceImporter(
795     // The compilation unit that will own the imported source
796     std::shared_ptr<CompilationUnit> cu,
797     const std::vector<IValue>* constant_table,
798     SourceLoader loader,
799     size_t version)
800     : pImpl(std::make_shared<SourceImporterImpl>(
801           std::move(cu),
802           constant_table,
803           std::move(loader),
804           version)) {}
805 
loadType(const QualifiedName & name) const806 TypePtr SourceImporter::loadType(const QualifiedName& name) const {
807   ScriptTypeParser type_parser(pImpl);
808   TypePtr t = type_parser.parseType(name.qualifiedName());
809   return t;
810 }
811 
LEGACY_import_methods(const Module & mod,const std::shared_ptr<Source> & src)812 void SourceImporter::LEGACY_import_methods(
813     const Module& mod,
814     const std::shared_ptr<Source>& src) {
815   pImpl->LEGACY_import_methods(mod, src);
816 }
817 SourceImporter::~SourceImporter() = default;
818 
819 } // namespace torch::jit
820