xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/backend_resolver.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/backends/backend_resolver.h>
2 #include <torch/csrc/jit/frontend/sugared_value.h>
3 #include <torch/custom_class.h>
4 
5 namespace torch {
6 namespace jit {
7 namespace {
8 // Essentially ClassNamespaceValue from import_source.cpp without the
9 // SourceImporterImpl reference. This helps resolve the
10 // __torch__.torch.classes.backends.{backend_name} symbols in the generated code
11 // for the LoweredModule.
12 struct ClassNamespaceValue : public SugaredValue {
ClassNamespaceValuetorch::jit::__anon033ecd720111::ClassNamespaceValue13   explicit ClassNamespaceValue(c10::QualifiedName name)
14       : basename_(std::move(name)) {}
15 
attrtorch::jit::__anon033ecd720111::ClassNamespaceValue16   std::shared_ptr<SugaredValue> attr(
17       const SourceRange& loc,
18       GraphFunction& m,
19       const std::string& name) override {
20     auto fullName = c10::QualifiedName(basename_, name);
21 
22     // Check to see if it is a custom class.
23     if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
24       return std::make_shared<ClassValue>(custom_class);
25     }
26 
27     // If it's not a custom class, assume it's another namespace
28     return std::make_shared<ClassNamespaceValue>(std::move(fullName));
29   }
30 
kindtorch::jit::__anon033ecd720111::ClassNamespaceValue31   std::string kind() const override {
32     return "Class Namespace";
33   }
34 
35  private:
36   c10::QualifiedName basename_;
37 };
38 
39 // A resolver just for resolving custom backend class lookups in the
40 // LoweredModule classes generated by the rest of the cdoe in this file.
41 struct LoweredModuleResolver : public Resolver {
resolveValuetorch::jit::__anon033ecd720111::LoweredModuleResolver42   std::shared_ptr<SugaredValue> resolveValue(
43       const std::string& name,
44       GraphFunction& m,
45       const SourceRange& loc) override {
46     if (name == "torch") {
47       return std::make_shared<BuiltinModule>("aten");
48     } else if (name == "__torch__") {
49       return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
50     } else if (name == "Exception") {
51       return std::make_shared<ExceptionValue>(name);
52     }
53 
54     return nullptr;
55   }
56 
resolveTypetorch::jit::__anon033ecd720111::LoweredModuleResolver57   TypePtr resolveType(const std::string& name, const SourceRange& loc)
58       override {
59     return nullptr;
60   }
61 };
62 } // namespace
63 
loweredModuleResolver()64 std::shared_ptr<Resolver> loweredModuleResolver() {
65   std::shared_ptr<Resolver> resolver =
66       std::make_shared<LoweredModuleResolver>();
67   return resolver;
68 }
69 
70 } // namespace jit
71 } // namespace torch
72