1 #include <torch/csrc/jit/backends/backend_resolver.h> 2 #include <torch/csrc/jit/frontend/sugared_value.h> 3 #include <torch/custom_class.h> 4 5 namespace torch { 6 namespace jit { 7 namespace { 8 // Essentially ClassNamespaceValue from import_source.cpp without the 9 // SourceImporterImpl reference. This helps resolve the 10 // __torch__.torch.classes.backends.{backend_name} symbols in the generated code 11 // for the LoweredModule. 12 struct ClassNamespaceValue : public SugaredValue { ClassNamespaceValuetorch::jit::__anon033ecd720111::ClassNamespaceValue13 explicit ClassNamespaceValue(c10::QualifiedName name) 14 : basename_(std::move(name)) {} 15 attrtorch::jit::__anon033ecd720111::ClassNamespaceValue16 std::shared_ptr<SugaredValue> attr( 17 const SourceRange& loc, 18 GraphFunction& m, 19 const std::string& name) override { 20 auto fullName = c10::QualifiedName(basename_, name); 21 22 // Check to see if it is a custom class. 23 if (auto custom_class = getCustomClass(fullName.qualifiedName())) { 24 return std::make_shared<ClassValue>(custom_class); 25 } 26 27 // If it's not a custom class, assume it's another namespace 28 return std::make_shared<ClassNamespaceValue>(std::move(fullName)); 29 } 30 kindtorch::jit::__anon033ecd720111::ClassNamespaceValue31 std::string kind() const override { 32 return "Class Namespace"; 33 } 34 35 private: 36 c10::QualifiedName basename_; 37 }; 38 39 // A resolver just for resolving custom backend class lookups in the 40 // LoweredModule classes generated by the rest of the cdoe in this file. 41 struct LoweredModuleResolver : public Resolver { resolveValuetorch::jit::__anon033ecd720111::LoweredModuleResolver42 std::shared_ptr<SugaredValue> resolveValue( 43 const std::string& name, 44 GraphFunction& m, 45 const SourceRange& loc) override { 46 if (name == "torch") { 47 return std::make_shared<BuiltinModule>("aten"); 48 } else if (name == "__torch__") { 49 return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name)); 50 } else if (name == "Exception") { 51 return std::make_shared<ExceptionValue>(name); 52 } 53 54 return nullptr; 55 } 56 resolveTypetorch::jit::__anon033ecd720111::LoweredModuleResolver57 TypePtr resolveType(const std::string& name, const SourceRange& loc) 58 override { 59 return nullptr; 60 } 61 }; 62 } // namespace 63 loweredModuleResolver()64std::shared_ptr<Resolver> loweredModuleResolver() { 65 std::shared_ptr<Resolver> resolver = 66 std::make_shared<LoweredModuleResolver>(); 67 return resolver; 68 } 69 70 } // namespace jit 71 } // namespace torch 72