1 #pragma once 2 #include <ATen/core/jit_type.h> 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/frontend/resolver.h> 5 #include <torch/csrc/jit/frontend/tree_views.h> 6 7 namespace torch::jit { 8 9 /** 10 * class ScriptTypeParser 11 * 12 * Parses expressions in our typed AST format (TreeView) into types and 13 * typenames. 14 */ 15 class TORCH_API ScriptTypeParser { 16 public: 17 explicit ScriptTypeParser() = default; ScriptTypeParser(ResolverPtr resolver)18 explicit ScriptTypeParser(ResolverPtr resolver) 19 : resolver_(std::move(resolver)) {} 20 21 c10::TypePtr parseTypeFromExpr(const Expr& expr) const; 22 23 std::optional<std::pair<c10::TypePtr, int32_t>> parseBroadcastList( 24 const Expr& expr) const; 25 26 c10::TypePtr parseType(const std::string& str); 27 28 FunctionSchema parseSchemaFromDef(const Def& def, bool skip_self); 29 30 c10::IValue parseClassConstant(const Assign& assign); 31 32 private: 33 c10::TypePtr parseTypeFromExprImpl(const Expr& expr) const; 34 35 std::optional<std::string> parseBaseTypeName(const Expr& expr) const; 36 at::TypePtr subscriptToType( 37 const std::string& typeName, 38 const Subscript& subscript) const; 39 std::vector<IValue> evaluateDefaults( 40 const SourceRange& r, 41 const std::vector<Expr>& default_types, 42 const std::vector<Expr>& default_exprs); 43 std::vector<Argument> parseArgsFromDecl(const Decl& decl, bool skip_self); 44 45 std::vector<Argument> parseReturnFromDecl(const Decl& decl); 46 47 ResolverPtr resolver_ = nullptr; 48 49 // Need to use `evaluateDefaults` in serialization 50 friend struct ConstantTableValue; 51 friend struct SourceImporterImpl; 52 }; 53 } // namespace torch::jit 54