xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/script_type_parser.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/jit_type.h>
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/frontend/resolver.h>
5 #include <torch/csrc/jit/frontend/tree_views.h>
6 
7 namespace torch::jit {
8 
9 /**
10  * class ScriptTypeParser
11  *
12  * Parses expressions in our typed AST format (TreeView) into types and
13  * typenames.
14  */
15 class TORCH_API ScriptTypeParser {
16  public:
17   explicit ScriptTypeParser() = default;
ScriptTypeParser(ResolverPtr resolver)18   explicit ScriptTypeParser(ResolverPtr resolver)
19       : resolver_(std::move(resolver)) {}
20 
21   c10::TypePtr parseTypeFromExpr(const Expr& expr) const;
22 
23   std::optional<std::pair<c10::TypePtr, int32_t>> parseBroadcastList(
24       const Expr& expr) const;
25 
26   c10::TypePtr parseType(const std::string& str);
27 
28   FunctionSchema parseSchemaFromDef(const Def& def, bool skip_self);
29 
30   c10::IValue parseClassConstant(const Assign& assign);
31 
32  private:
33   c10::TypePtr parseTypeFromExprImpl(const Expr& expr) const;
34 
35   std::optional<std::string> parseBaseTypeName(const Expr& expr) const;
36   at::TypePtr subscriptToType(
37       const std::string& typeName,
38       const Subscript& subscript) const;
39   std::vector<IValue> evaluateDefaults(
40       const SourceRange& r,
41       const std::vector<Expr>& default_types,
42       const std::vector<Expr>& default_exprs);
43   std::vector<Argument> parseArgsFromDecl(const Decl& decl, bool skip_self);
44 
45   std::vector<Argument> parseReturnFromDecl(const Decl& decl);
46 
47   ResolverPtr resolver_ = nullptr;
48 
49   // Need to use `evaluateDefaults` in serialization
50   friend struct ConstantTableValue;
51   friend struct SourceImporterImpl;
52 };
53 } // namespace torch::jit
54