xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/script_type_parser.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/script_type_parser.h>
2 
3 #include <ATen/core/type_factory.h>
4 #include <torch/csrc/jit/frontend/parser.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/custom_class.h>
7 
8 namespace torch::jit {
9 namespace {
10 
isTorch(const Expr & expr)11 bool isTorch(const Expr& expr) {
12   return expr.kind() == TK_VAR && Var(expr).name().name() == "torch";
13 }
14 
collectQualname(const Select & select)15 std::string collectQualname(const Select& select) {
16   Expr base = select.value();
17   if (base.kind() == TK_VAR) {
18     return Var(base).name().name() + "." + select.selector().name();
19   }
20   std::string basename = collectQualname(Select(base));
21   return basename + "." + select.selector().name();
22 }
23 
string_to_type_lut()24 const std::unordered_map<std::string, c10::TypePtr>& string_to_type_lut() {
25   return c10::DefaultTypeFactory::basePythonTypes();
26 }
27 
28 } // namespace
29 
subscriptToType(const std::string & typeName,const Subscript & subscript) const30 TypePtr ScriptTypeParser::subscriptToType(
31     const std::string& typeName,
32     const Subscript& subscript) const {
33   if (typeName == "Tuple" || typeName == "tuple") {
34     if (subscript.subscript_exprs().size() == 1 &&
35         subscript.subscript_exprs()[0].kind() == TK_TUPLE_LITERAL) {
36       // `typing.Tuple` special cases syntax for empty tuple annotations,
37       // i.e. `typing.Tuple[()]`. Allow for parsing an empty tuple literal
38       // here. See https://docs.python.org/3/library/typing.html#typing.Tuple
39       auto tup_literal = TupleLiteral(subscript.subscript_exprs()[0]);
40       if (!tup_literal.inputs().empty()) {
41         throw(
42             ErrorReport(tup_literal.range())
43             << "Tuple literal in Tuple type annotation must not "
44             << "have any elements!");
45       }
46       return TupleType::create({});
47     }
48     std::vector<TypePtr> subscript_expr_types;
49     for (auto expr : subscript.subscript_exprs()) {
50       subscript_expr_types.emplace_back(parseTypeFromExprImpl(expr));
51     }
52     return TupleType::create(subscript_expr_types);
53   } else if (typeName == "List" || typeName == "list") {
54     if (subscript.subscript_exprs().size() != 1) {
55       throw ErrorReport(subscript)
56           << " expected exactly one element type but found "
57           << subscript.subscript_exprs().size();
58     }
59     auto elem_type =
60         parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
61     return ListType::create(elem_type);
62 
63   } else if (typeName == "Optional") {
64     if (subscript.subscript_exprs().size() != 1) {
65       throw ErrorReport(subscript)
66           << " expected exactly one element type but found "
67           << subscript.subscript_exprs().size();
68     }
69     auto elem_type =
70         parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
71     return OptionalType::create(elem_type);
72 
73   } else if (typeName == "Union") {
74     std::vector<TypePtr> subscript_expr_types;
75     subscript_expr_types.reserve(subscript.subscript_exprs().size());
76     for (auto expr : subscript.subscript_exprs()) {
77       subscript_expr_types.emplace_back(parseTypeFromExprImpl(expr));
78     }
79     return UnionType::create(subscript_expr_types);
80   } else if (typeName == "Future" || typeName == "torch.jit.Future") {
81     if (subscript.subscript_exprs().size() != 1) {
82       throw ErrorReport(subscript)
83           << " expected exactly one element type but found "
84           << subscript.subscript_exprs().size();
85     }
86     auto elem_type =
87         parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
88     return FutureType::create(elem_type);
89   } else if (typeName == "Await" || typeName == "torch.jit._Await") {
90     if (subscript.subscript_exprs().size() != 1) {
91       throw ErrorReport(subscript)
92           << " expected exactly one element type but found "
93           << subscript.subscript_exprs().size();
94     }
95     auto elem_type =
96         parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
97     return AwaitType::create(elem_type);
98   } else if (typeName == "RRef") {
99     if (subscript.subscript_exprs().size() != 1) {
100       throw ErrorReport(subscript)
101           << " expected exactly one element type but found "
102           << subscript.subscript_exprs().size();
103     }
104     auto elem_type =
105         parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
106     return RRefType::create(elem_type);
107   } else if (typeName == "Dict" || typeName == "dict") {
108     if (subscript.subscript_exprs().size() != 2) {
109       throw ErrorReport(subscript)
110           << " expected exactly 2 element types but found "
111           << subscript.subscript_exprs().size();
112     }
113     auto key_type = parseTypeFromExprImpl(subscript.subscript_exprs()[0]);
114     auto value_type = parseTypeFromExprImpl(subscript.subscript_exprs()[1]);
115     return DictType::create(key_type, value_type);
116   } else {
117     throw ErrorReport(subscript.range())
118         << "Unknown type constructor " << typeName;
119   }
120 }
121 
parseBroadcastList(const Expr & expr) const122 std::optional<std::pair<TypePtr, int32_t>> ScriptTypeParser::parseBroadcastList(
123     const Expr& expr) const {
124   // Alias torch.nn._common_types._size_?_t to BroadcastingList?[int]
125   if (expr.kind() == TK_VAR) {
126     auto var = Var(expr);
127     auto& name = var.name().name();
128     constexpr auto _size_prefix = "_size_";
129     constexpr auto _size_suffix = "_t";
130     constexpr auto _size_n_len = 9; // strlen("_size_X_t")
131     constexpr auto _size_prefix_len = 6; // strlen("_size_");
132     if (name.find(_size_prefix) == 0 && name.length() == _size_n_len &&
133         name.find(_size_suffix) == _size_prefix_len + 1 &&
134         ::isdigit(name[_size_prefix_len])) {
135       int n = name[_size_prefix_len] - '0';
136       return std::pair<TypePtr, int32_t>(ListType::create(IntType::get()), n);
137     }
138   }
139 
140   if (expr.kind() != TK_SUBSCRIPT)
141     return std::nullopt;
142   auto subscript = Subscript(expr);
143   if (subscript.value().kind() != TK_VAR)
144     return std::nullopt;
145   auto var = Var(subscript.value());
146   auto subscript_exprs = subscript.subscript_exprs();
147 
148   // handle the case where the BroadcastingList is wrapped in a Optional type
149   if (var.name().name() == "Optional") {
150     auto broadcast_list = parseBroadcastList(subscript_exprs[0]);
151     if (broadcast_list) {
152       TypePtr opt_type = OptionalType::create(broadcast_list->first);
153       return std::pair<TypePtr, int32_t>(opt_type, broadcast_list->second);
154     } else {
155       return std::nullopt;
156     }
157   } else if (var.name().name().find("BroadcastingList") != 0) {
158     return std::nullopt;
159   }
160 
161   if (subscript_exprs.size() != 1)
162     throw ErrorReport(subscript.subscript_exprs().range())
163         << "BroadcastingList/Optional[BroadcastingList] "
164            "must be subscripted with a type";
165 
166   auto typ = subscript_exprs[0];
167   auto len = var.name().name().substr(strlen("BroadcastingList"));
168 
169   if (typ.kind() != TK_VAR)
170     throw ErrorReport(subscript.value().range())
171         << "Subscripted type must be a type identifier";
172 
173   auto value_name = Var(typ).name().name();
174   if (value_name != "float" && value_name != "int")
175     throw ErrorReport(subscript.value().range())
176         << "Broadcastable lists only supported for int or float";
177 
178   auto elem_ptr = string_to_type_lut().find(value_name);
179   AT_ASSERT(elem_ptr != string_to_type_lut().end());
180   TypePtr list_ptr = ListType::create(elem_ptr->second);
181 
182   const char* len_c = len.c_str();
183   char* end = nullptr;
184   size_t len_v = strtoull(len_c, &end, 10);
185   if (end != len_c + len.size()) {
186     throw(
187         ErrorReport(subscript.subscript_exprs().range())
188         << "subscript of Broadcastable list must be a positive integer");
189   }
190   return std::pair<TypePtr, int32_t>(list_ptr, len_v);
191 }
192 
193 // gets the base type name given namespaces where the types live
194 // turns torch.Tensor -> Tensor, X -> X
parseBaseTypeName(const Expr & expr) const195 std::optional<std::string> ScriptTypeParser::parseBaseTypeName(
196     const Expr& expr) const {
197   switch (expr.kind()) {
198     case TK_VAR: {
199       return Var(expr).name().name();
200     }
201     case TK_NONE: {
202       return "None";
203     }
204     case TK_NONE_TYPE: {
205       return "NoneType";
206     }
207     case '.': {
208       auto select = Select(expr);
209       const std::string& name = select.selector().name();
210       // Special case for torch.Tensor and its' subclasses
211       const std::unordered_set<std::string> tensor_subtypes = {
212           "Tensor",
213           "LongTensor",
214           "FloatTensor",
215           "DoubleTensor",
216           "IntTensor",
217           "ShortTensor",
218           "HalfTensor",
219           "CharTensor",
220           "ByteTensor",
221           "BoolTensor"};
222       if (isTorch(select.value()) && tensor_subtypes.count(name) == 1) {
223         return name;
224       } else {
225         // Otherwise, it's a fully qualified class name
226         return collectQualname(select);
227       }
228     } break;
229   }
230   return std::nullopt;
231 }
232 
parseTypeFromExpr(const Expr & expr) const233 TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {
234   // the resolver needs to recursively resolve the expression, so to avoid
235   // resolving all type expr subtrees we only use it for the top level
236   // expression and base type names.
237   if (expr.kind() == '|') {
238     auto converted = pep604union_to_union(expr);
239     return parseTypeFromExpr(converted);
240   }
241   if (resolver_) {
242     if (auto typePtr =
243             resolver_->resolveType(expr.range().text().str(), expr.range())) {
244       return typePtr;
245     }
246   }
247   return parseTypeFromExprImpl(expr);
248 }
249 
parseTypeFromExprImpl(const Expr & expr) const250 TypePtr ScriptTypeParser::parseTypeFromExprImpl(const Expr& expr) const {
251   if (expr.kind() == '|') {
252     auto converted = pep604union_to_union(expr);
253     return parseTypeFromExprImpl(converted);
254   }
255   if (expr.kind() == TK_SUBSCRIPT) {
256     auto subscript = Subscript(expr);
257     auto value_name = parseBaseTypeName(subscript.value());
258     if (!value_name) {
259       throw ErrorReport(subscript.value().range())
260           << "Subscripted type must be a type identifier";
261     }
262     return subscriptToType(*value_name, subscript);
263 
264   } else if (expr.kind() == TK_STRINGLITERAL) {
265     const auto& type_name = StringLiteral(expr).text();
266 
267     // Check if the type is a custom class. This is done by checking
268     // if type_name starts with "torch.classes."
269     if (type_name.find("torch.classes.") == 0) {
270       auto custom_class_type = getCustomClass("__torch__." + type_name);
271       return custom_class_type;
272     }
273 
274     // `torch.cuda.Stream` and `torch.cuda.Event` are aliased as
275     // custom classes of type torch.classes.cuda.Stream and
276     // torch.classes.cuda.Event respectively. Return the respective
277     // custom class types for these two cases.
278     if (type_name.find("torch.cuda.Stream") == 0) {
279       auto custom_class_type =
280           getCustomClass("__torch__.torch.classes.cuda.Stream");
281       return custom_class_type;
282     }
283 
284     if (type_name.find("torch.cuda.Event") == 0) {
285       auto custom_class_type =
286           getCustomClass("__torch__.torch.classes.cuda.Event");
287       return custom_class_type;
288     }
289 
290     if (resolver_) {
291       if (auto typePtr = resolver_->resolveType(type_name, expr.range())) {
292         return typePtr;
293       }
294     }
295 
296     throw ErrorReport(expr) << "Unknown type name '" << type_name << "'";
297   } else if (auto name = parseBaseTypeName(expr)) {
298     auto itr = string_to_type_lut().find(*name);
299     if (itr != string_to_type_lut().end()) {
300       return itr->second;
301     }
302     if (resolver_) {
303       if (auto typePtr = resolver_->resolveType(*name, expr.range())) {
304         return typePtr;
305       }
306     }
307 
308     if (auto custom_class_type = getCustomClass(*name)) {
309       return custom_class_type;
310     }
311 
312     throw ErrorReport(expr) << "Unknown type name '" << *name << "'";
313   }
314   throw ErrorReport(expr.range())
315       << "Expression of type " << kindToString(expr.kind())
316       << " cannot be used in a type expression";
317 }
318 
parseType(const std::string & str)319 TypePtr ScriptTypeParser::parseType(const std::string& str) {
320   Parser p(std::make_shared<Source>(str));
321   return parseTypeFromExpr(p.parseExp());
322 }
323 
evaluateDefaults(const SourceRange & r,const std::vector<Expr> & default_types,const std::vector<Expr> & default_exprs)324 std::vector<IValue> ScriptTypeParser::evaluateDefaults(
325     const SourceRange& r,
326     const std::vector<Expr>& default_types,
327     const std::vector<Expr>& default_exprs) {
328   std::vector<IValue> default_values;
329   if (default_exprs.empty())
330     return default_values;
331   // To evaluate the default expressions, we create a graph with no inputs,
332   // and whose returns are the default values we need.
333   // We then run constant prop on this graph and check the results are
334   // constant. This approach avoids having to have separate handling of
335   // default arguments from standard expressions by piecing together existing
336   // machinery for graph generation, constant propagation, and constant
337   // extraction.
338   auto tuple_type = Subscript::create(
339       r,
340       Var::create(r, Ident::create(r, "Tuple")),
341       List<Expr>::create(r, default_types));
342   auto blank_decl = Decl::create(
343       r, List<Param>::create(r, {}), Maybe<Expr>::create(r, tuple_type));
344 
345   auto tuple_expr =
346       TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
347   auto ret = Return::create(r, tuple_expr);
348   auto def = Def::create(
349       r,
350       Ident::create(r, "defaults"),
351       blank_decl,
352       List<Stmt>::create(r, {ret}));
353 
354   CompilationUnit cu;
355   cu.define(
356       std::nullopt,
357       /*properties=*/{},
358       /*propResolvers=*/{},
359       {def},
360       {resolver_},
361       nullptr);
362   Stack stack;
363   // XXX: We need to turn optimization off here because otherwise we try to
364   // recursively initialize stuff in DecomposeOps.
365   GraphOptimizerEnabledGuard guard(false);
366   auto& f = cu.get_function(def.name().name());
367   auto* gf = dynamic_cast<GraphFunction*>(&f);
368   TORCH_INTERNAL_ASSERT(gf);
369   // 2024.08.14: Since we are starting to deprecate Torchscript usages,
370   // we are going to log all the calls for GraphFunction::run. The logging was
371   // noisy we also call GraphFunction::run for the default value evaluation
372   // which generates a lot of useless log samples. Therefore as a workaround we
373   // just directly use the executor API which avoids this placing producing
374   // un-necessary log entries.
375   gf->get_executor().run(stack);
376   return stack.at(0).toTupleRef().elements().vec();
377 }
378 
parseArgsFromDecl(const Decl & decl,bool skip_self)379 std::vector<Argument> ScriptTypeParser::parseArgsFromDecl(
380     const Decl& decl,
381     bool skip_self) {
382   auto params_begin = decl.params().begin();
383   auto params_end = decl.params().end();
384   if (skip_self) {
385     ++params_begin;
386   }
387   std::vector<Argument> retval;
388 
389   std::vector<Expr> default_types;
390   std::vector<Expr> default_exprs;
391   // gather any non-empty default arguments
392   for (auto it = params_begin; it != params_end; ++it) {
393     auto param = *it;
394     auto def = param.defaultValue();
395     if (def.present()) {
396       if (!param.type().present()) {
397         // We require explicit type-hints for default expressions.
398         // If param doesn't have a type, we could default to "Tensor",
399         // just like what happens in the Python frontend.
400         // However here things are a bit more complicated, because
401         // default expressions are evaluated using a custom-built
402         // graph, and error messages coming out of that in case
403         // the type doesn't match the value are quite obscure.
404         throw ErrorReport(param.range())
405             << "Keyword arguments with defaults need to be type-hinted (TorchScript C++ frontend)";
406       }
407       default_types.emplace_back(param.type().get());
408       default_exprs.emplace_back(def.get());
409     }
410   }
411 
412   auto default_values =
413       evaluateDefaults(decl.range(), default_types, default_exprs);
414 
415   auto defaults_it = default_values.begin();
416   for (auto it = params_begin; it != params_end; ++it) {
417     auto decl_arg = *it;
418 
419     TypePtr type;
420     std::optional<int32_t> N = std::nullopt;
421     if (!decl_arg.type().present()) {
422       // If this param doesn't have a type, default to "tensor"
423       type = TensorType::getInferred();
424     } else {
425       // BroadcastList list can only appear at the argument level
426       Expr type_expr = decl_arg.type().get();
427       if (auto maybe_broad_list = parseBroadcastList(type_expr)) {
428         type = maybe_broad_list->first;
429         N = maybe_broad_list->second;
430       } else {
431         type = parseTypeFromExpr(decl_arg.type().get());
432       }
433     }
434     std::optional<IValue> default_value = std::nullopt;
435     if (decl_arg.defaultValue().present()) {
436       default_value = *defaults_it++;
437     }
438     auto arg = Argument(
439         decl_arg.ident().name(),
440         type,
441         N,
442         default_value,
443         decl_arg.kwarg_only(),
444         /*alias_info=*/std::nullopt);
445     retval.push_back(arg);
446   }
447   return retval;
448 }
449 
parseReturnFromDecl(const Decl & decl)450 std::vector<Argument> ScriptTypeParser::parseReturnFromDecl(const Decl& decl) {
451   // we represent no annoation on a return type as having no values in the
452   // schema's return() list
453   // in emitReturn we take the actual return value to be the value of the
454   // return statement if no one was provided here
455   if (!decl.return_type().present())
456     return {};
457 
458   if (parseBroadcastList(decl.return_type().get()))
459     throw ErrorReport(decl.return_type().range())
460         << "Broadcastable lists cannot appear as a return type";
461 
462   TypePtr parsed_type;
463   Expr type_expr = decl.return_type().get();
464   parsed_type = parseTypeFromExpr(type_expr);
465   return {Argument(
466       "",
467       parsed_type,
468       /*N =*/std::nullopt,
469       /*default_value =*/std::nullopt,
470       /*kwarg_only =*/false)};
471 }
parseSchemaFromDef(const Def & def,bool skip_self)472 FunctionSchema ScriptTypeParser::parseSchemaFromDef(
473     const Def& def,
474     bool skip_self) {
475   const auto name = def.name().name();
476   std::vector<Argument> args = parseArgsFromDecl(def.decl(), skip_self);
477   std::vector<Argument> returns = parseReturnFromDecl(def.decl());
478   return FunctionSchema(
479       name, "", std::move(args), std::move(returns), false, false);
480 }
481 
parseClassConstant(const Assign & assign)482 c10::IValue ScriptTypeParser::parseClassConstant(const Assign& assign) {
483   if (assign.lhs().kind() != TK_VAR) {
484     throw ErrorReport(assign.range())
485         << "Expected to a variable for class constant";
486   }
487   if (!assign.type().present()) {
488     throw ErrorReport(assign.range())
489         << "Expected a type to present for class constant";
490   }
491   const auto final_type = assign.type().get();
492   auto expr = assign.rhs().get();
493   if (final_type.kind() != TK_SUBSCRIPT) {
494     throw ErrorReport(assign.range())
495         << "Expected subscripted type for class constant";
496   }
497   auto subscript = Subscript(final_type);
498   auto value_name = parseBaseTypeName(subscript.value());
499   if (!value_name) {
500     throw ErrorReport(subscript.value().range())
501         << "Subscripted type must be a type identifier";
502   }
503   if (*value_name != "Final") {
504     throw ErrorReport(subscript.range())
505         << "Base type must be Final for class constant";
506   }
507   if (subscript.subscript_exprs().size() != 1) {
508     throw ErrorReport(subscript)
509         << " expected exactly one element type but found "
510         << subscript.subscript_exprs().size();
511   }
512   auto type = *subscript.subscript_exprs().begin();
513   auto default_val = evaluateDefaults(expr.range(), {type}, {expr});
514   return *default_val.begin();
515 }
516 
517 } // namespace torch::jit
518