1 #include <torch/csrc/jit/frontend/script_type_parser.h>
2
3 #include <ATen/core/type_factory.h>
4 #include <torch/csrc/jit/frontend/parser.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/custom_class.h>
7
8 namespace torch::jit {
9 namespace {
10
isTorch(const Expr & expr)11 bool isTorch(const Expr& expr) {
12 return expr.kind() == TK_VAR && Var(expr).name().name() == "torch";
13 }
14
collectQualname(const Select & select)15 std::string collectQualname(const Select& select) {
16 Expr base = select.value();
17 if (base.kind() == TK_VAR) {
18 return Var(base).name().name() + "." + select.selector().name();
19 }
20 std::string basename = collectQualname(Select(base));
21 return basename + "." + select.selector().name();
22 }
23
string_to_type_lut()24 const std::unordered_map<std::string, c10::TypePtr>& string_to_type_lut() {
25 return c10::DefaultTypeFactory::basePythonTypes();
26 }
27
28 } // namespace
29
subscriptToType(const std::string & typeName,const Subscript & subscript) const30 TypePtr ScriptTypeParser::subscriptToType(
31 const std::string& typeName,
32 const Subscript& subscript) const {
33 if (typeName == "Tuple" || typeName == "tuple") {
34 if (subscript.subscript_exprs().size() == 1 &&
35 subscript.subscript_exprs()[0].kind() == TK_TUPLE_LITERAL) {
36 // `typing.Tuple` special cases syntax for empty tuple annotations,
37 // i.e. `typing.Tuple[()]`. Allow for parsing an empty tuple literal
38 // here. See https://docs.python.org/3/library/typing.html#typing.Tuple
39 auto tup_literal = TupleLiteral(subscript.subscript_exprs()[0]);
40 if (!tup_literal.inputs().empty()) {
41 throw(
42 ErrorReport(tup_literal.range())
43 << "Tuple literal in Tuple type annotation must not "
44 << "have any elements!");
45 }
46 return TupleType::create({});
47 }
48 std::vector<TypePtr> subscript_expr_types;
49 for (auto expr : subscript.subscript_exprs()) {
50 subscript_expr_types.emplace_back(parseTypeFromExprImpl(expr));
51 }
52 return TupleType::create(subscript_expr_types);
53 } else if (typeName == "List" || typeName == "list") {
54 if (subscript.subscript_exprs().size() != 1) {
55 throw ErrorReport(subscript)
56 << " expected exactly one element type but found "
57 << subscript.subscript_exprs().size();
58 }
59 auto elem_type =
60 parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
61 return ListType::create(elem_type);
62
63 } else if (typeName == "Optional") {
64 if (subscript.subscript_exprs().size() != 1) {
65 throw ErrorReport(subscript)
66 << " expected exactly one element type but found "
67 << subscript.subscript_exprs().size();
68 }
69 auto elem_type =
70 parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
71 return OptionalType::create(elem_type);
72
73 } else if (typeName == "Union") {
74 std::vector<TypePtr> subscript_expr_types;
75 subscript_expr_types.reserve(subscript.subscript_exprs().size());
76 for (auto expr : subscript.subscript_exprs()) {
77 subscript_expr_types.emplace_back(parseTypeFromExprImpl(expr));
78 }
79 return UnionType::create(subscript_expr_types);
80 } else if (typeName == "Future" || typeName == "torch.jit.Future") {
81 if (subscript.subscript_exprs().size() != 1) {
82 throw ErrorReport(subscript)
83 << " expected exactly one element type but found "
84 << subscript.subscript_exprs().size();
85 }
86 auto elem_type =
87 parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
88 return FutureType::create(elem_type);
89 } else if (typeName == "Await" || typeName == "torch.jit._Await") {
90 if (subscript.subscript_exprs().size() != 1) {
91 throw ErrorReport(subscript)
92 << " expected exactly one element type but found "
93 << subscript.subscript_exprs().size();
94 }
95 auto elem_type =
96 parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
97 return AwaitType::create(elem_type);
98 } else if (typeName == "RRef") {
99 if (subscript.subscript_exprs().size() != 1) {
100 throw ErrorReport(subscript)
101 << " expected exactly one element type but found "
102 << subscript.subscript_exprs().size();
103 }
104 auto elem_type =
105 parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
106 return RRefType::create(elem_type);
107 } else if (typeName == "Dict" || typeName == "dict") {
108 if (subscript.subscript_exprs().size() != 2) {
109 throw ErrorReport(subscript)
110 << " expected exactly 2 element types but found "
111 << subscript.subscript_exprs().size();
112 }
113 auto key_type = parseTypeFromExprImpl(subscript.subscript_exprs()[0]);
114 auto value_type = parseTypeFromExprImpl(subscript.subscript_exprs()[1]);
115 return DictType::create(key_type, value_type);
116 } else {
117 throw ErrorReport(subscript.range())
118 << "Unknown type constructor " << typeName;
119 }
120 }
121
parseBroadcastList(const Expr & expr) const122 std::optional<std::pair<TypePtr, int32_t>> ScriptTypeParser::parseBroadcastList(
123 const Expr& expr) const {
124 // Alias torch.nn._common_types._size_?_t to BroadcastingList?[int]
125 if (expr.kind() == TK_VAR) {
126 auto var = Var(expr);
127 auto& name = var.name().name();
128 constexpr auto _size_prefix = "_size_";
129 constexpr auto _size_suffix = "_t";
130 constexpr auto _size_n_len = 9; // strlen("_size_X_t")
131 constexpr auto _size_prefix_len = 6; // strlen("_size_");
132 if (name.find(_size_prefix) == 0 && name.length() == _size_n_len &&
133 name.find(_size_suffix) == _size_prefix_len + 1 &&
134 ::isdigit(name[_size_prefix_len])) {
135 int n = name[_size_prefix_len] - '0';
136 return std::pair<TypePtr, int32_t>(ListType::create(IntType::get()), n);
137 }
138 }
139
140 if (expr.kind() != TK_SUBSCRIPT)
141 return std::nullopt;
142 auto subscript = Subscript(expr);
143 if (subscript.value().kind() != TK_VAR)
144 return std::nullopt;
145 auto var = Var(subscript.value());
146 auto subscript_exprs = subscript.subscript_exprs();
147
148 // handle the case where the BroadcastingList is wrapped in a Optional type
149 if (var.name().name() == "Optional") {
150 auto broadcast_list = parseBroadcastList(subscript_exprs[0]);
151 if (broadcast_list) {
152 TypePtr opt_type = OptionalType::create(broadcast_list->first);
153 return std::pair<TypePtr, int32_t>(opt_type, broadcast_list->second);
154 } else {
155 return std::nullopt;
156 }
157 } else if (var.name().name().find("BroadcastingList") != 0) {
158 return std::nullopt;
159 }
160
161 if (subscript_exprs.size() != 1)
162 throw ErrorReport(subscript.subscript_exprs().range())
163 << "BroadcastingList/Optional[BroadcastingList] "
164 "must be subscripted with a type";
165
166 auto typ = subscript_exprs[0];
167 auto len = var.name().name().substr(strlen("BroadcastingList"));
168
169 if (typ.kind() != TK_VAR)
170 throw ErrorReport(subscript.value().range())
171 << "Subscripted type must be a type identifier";
172
173 auto value_name = Var(typ).name().name();
174 if (value_name != "float" && value_name != "int")
175 throw ErrorReport(subscript.value().range())
176 << "Broadcastable lists only supported for int or float";
177
178 auto elem_ptr = string_to_type_lut().find(value_name);
179 AT_ASSERT(elem_ptr != string_to_type_lut().end());
180 TypePtr list_ptr = ListType::create(elem_ptr->second);
181
182 const char* len_c = len.c_str();
183 char* end = nullptr;
184 size_t len_v = strtoull(len_c, &end, 10);
185 if (end != len_c + len.size()) {
186 throw(
187 ErrorReport(subscript.subscript_exprs().range())
188 << "subscript of Broadcastable list must be a positive integer");
189 }
190 return std::pair<TypePtr, int32_t>(list_ptr, len_v);
191 }
192
193 // gets the base type name given namespaces where the types live
194 // turns torch.Tensor -> Tensor, X -> X
parseBaseTypeName(const Expr & expr) const195 std::optional<std::string> ScriptTypeParser::parseBaseTypeName(
196 const Expr& expr) const {
197 switch (expr.kind()) {
198 case TK_VAR: {
199 return Var(expr).name().name();
200 }
201 case TK_NONE: {
202 return "None";
203 }
204 case TK_NONE_TYPE: {
205 return "NoneType";
206 }
207 case '.': {
208 auto select = Select(expr);
209 const std::string& name = select.selector().name();
210 // Special case for torch.Tensor and its' subclasses
211 const std::unordered_set<std::string> tensor_subtypes = {
212 "Tensor",
213 "LongTensor",
214 "FloatTensor",
215 "DoubleTensor",
216 "IntTensor",
217 "ShortTensor",
218 "HalfTensor",
219 "CharTensor",
220 "ByteTensor",
221 "BoolTensor"};
222 if (isTorch(select.value()) && tensor_subtypes.count(name) == 1) {
223 return name;
224 } else {
225 // Otherwise, it's a fully qualified class name
226 return collectQualname(select);
227 }
228 } break;
229 }
230 return std::nullopt;
231 }
232
parseTypeFromExpr(const Expr & expr) const233 TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {
234 // the resolver needs to recursively resolve the expression, so to avoid
235 // resolving all type expr subtrees we only use it for the top level
236 // expression and base type names.
237 if (expr.kind() == '|') {
238 auto converted = pep604union_to_union(expr);
239 return parseTypeFromExpr(converted);
240 }
241 if (resolver_) {
242 if (auto typePtr =
243 resolver_->resolveType(expr.range().text().str(), expr.range())) {
244 return typePtr;
245 }
246 }
247 return parseTypeFromExprImpl(expr);
248 }
249
parseTypeFromExprImpl(const Expr & expr) const250 TypePtr ScriptTypeParser::parseTypeFromExprImpl(const Expr& expr) const {
251 if (expr.kind() == '|') {
252 auto converted = pep604union_to_union(expr);
253 return parseTypeFromExprImpl(converted);
254 }
255 if (expr.kind() == TK_SUBSCRIPT) {
256 auto subscript = Subscript(expr);
257 auto value_name = parseBaseTypeName(subscript.value());
258 if (!value_name) {
259 throw ErrorReport(subscript.value().range())
260 << "Subscripted type must be a type identifier";
261 }
262 return subscriptToType(*value_name, subscript);
263
264 } else if (expr.kind() == TK_STRINGLITERAL) {
265 const auto& type_name = StringLiteral(expr).text();
266
267 // Check if the type is a custom class. This is done by checking
268 // if type_name starts with "torch.classes."
269 if (type_name.find("torch.classes.") == 0) {
270 auto custom_class_type = getCustomClass("__torch__." + type_name);
271 return custom_class_type;
272 }
273
274 // `torch.cuda.Stream` and `torch.cuda.Event` are aliased as
275 // custom classes of type torch.classes.cuda.Stream and
276 // torch.classes.cuda.Event respectively. Return the respective
277 // custom class types for these two cases.
278 if (type_name.find("torch.cuda.Stream") == 0) {
279 auto custom_class_type =
280 getCustomClass("__torch__.torch.classes.cuda.Stream");
281 return custom_class_type;
282 }
283
284 if (type_name.find("torch.cuda.Event") == 0) {
285 auto custom_class_type =
286 getCustomClass("__torch__.torch.classes.cuda.Event");
287 return custom_class_type;
288 }
289
290 if (resolver_) {
291 if (auto typePtr = resolver_->resolveType(type_name, expr.range())) {
292 return typePtr;
293 }
294 }
295
296 throw ErrorReport(expr) << "Unknown type name '" << type_name << "'";
297 } else if (auto name = parseBaseTypeName(expr)) {
298 auto itr = string_to_type_lut().find(*name);
299 if (itr != string_to_type_lut().end()) {
300 return itr->second;
301 }
302 if (resolver_) {
303 if (auto typePtr = resolver_->resolveType(*name, expr.range())) {
304 return typePtr;
305 }
306 }
307
308 if (auto custom_class_type = getCustomClass(*name)) {
309 return custom_class_type;
310 }
311
312 throw ErrorReport(expr) << "Unknown type name '" << *name << "'";
313 }
314 throw ErrorReport(expr.range())
315 << "Expression of type " << kindToString(expr.kind())
316 << " cannot be used in a type expression";
317 }
318
parseType(const std::string & str)319 TypePtr ScriptTypeParser::parseType(const std::string& str) {
320 Parser p(std::make_shared<Source>(str));
321 return parseTypeFromExpr(p.parseExp());
322 }
323
evaluateDefaults(const SourceRange & r,const std::vector<Expr> & default_types,const std::vector<Expr> & default_exprs)324 std::vector<IValue> ScriptTypeParser::evaluateDefaults(
325 const SourceRange& r,
326 const std::vector<Expr>& default_types,
327 const std::vector<Expr>& default_exprs) {
328 std::vector<IValue> default_values;
329 if (default_exprs.empty())
330 return default_values;
331 // To evaluate the default expressions, we create a graph with no inputs,
332 // and whose returns are the default values we need.
333 // We then run constant prop on this graph and check the results are
334 // constant. This approach avoids having to have separate handling of
335 // default arguments from standard expressions by piecing together existing
336 // machinery for graph generation, constant propagation, and constant
337 // extraction.
338 auto tuple_type = Subscript::create(
339 r,
340 Var::create(r, Ident::create(r, "Tuple")),
341 List<Expr>::create(r, default_types));
342 auto blank_decl = Decl::create(
343 r, List<Param>::create(r, {}), Maybe<Expr>::create(r, tuple_type));
344
345 auto tuple_expr =
346 TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
347 auto ret = Return::create(r, tuple_expr);
348 auto def = Def::create(
349 r,
350 Ident::create(r, "defaults"),
351 blank_decl,
352 List<Stmt>::create(r, {ret}));
353
354 CompilationUnit cu;
355 cu.define(
356 std::nullopt,
357 /*properties=*/{},
358 /*propResolvers=*/{},
359 {def},
360 {resolver_},
361 nullptr);
362 Stack stack;
363 // XXX: We need to turn optimization off here because otherwise we try to
364 // recursively initialize stuff in DecomposeOps.
365 GraphOptimizerEnabledGuard guard(false);
366 auto& f = cu.get_function(def.name().name());
367 auto* gf = dynamic_cast<GraphFunction*>(&f);
368 TORCH_INTERNAL_ASSERT(gf);
369 // 2024.08.14: Since we are starting to deprecate Torchscript usages,
370 // we are going to log all the calls for GraphFunction::run. The logging was
371 // noisy we also call GraphFunction::run for the default value evaluation
372 // which generates a lot of useless log samples. Therefore as a workaround we
373 // just directly use the executor API which avoids this placing producing
374 // un-necessary log entries.
375 gf->get_executor().run(stack);
376 return stack.at(0).toTupleRef().elements().vec();
377 }
378
parseArgsFromDecl(const Decl & decl,bool skip_self)379 std::vector<Argument> ScriptTypeParser::parseArgsFromDecl(
380 const Decl& decl,
381 bool skip_self) {
382 auto params_begin = decl.params().begin();
383 auto params_end = decl.params().end();
384 if (skip_self) {
385 ++params_begin;
386 }
387 std::vector<Argument> retval;
388
389 std::vector<Expr> default_types;
390 std::vector<Expr> default_exprs;
391 // gather any non-empty default arguments
392 for (auto it = params_begin; it != params_end; ++it) {
393 auto param = *it;
394 auto def = param.defaultValue();
395 if (def.present()) {
396 if (!param.type().present()) {
397 // We require explicit type-hints for default expressions.
398 // If param doesn't have a type, we could default to "Tensor",
399 // just like what happens in the Python frontend.
400 // However here things are a bit more complicated, because
401 // default expressions are evaluated using a custom-built
402 // graph, and error messages coming out of that in case
403 // the type doesn't match the value are quite obscure.
404 throw ErrorReport(param.range())
405 << "Keyword arguments with defaults need to be type-hinted (TorchScript C++ frontend)";
406 }
407 default_types.emplace_back(param.type().get());
408 default_exprs.emplace_back(def.get());
409 }
410 }
411
412 auto default_values =
413 evaluateDefaults(decl.range(), default_types, default_exprs);
414
415 auto defaults_it = default_values.begin();
416 for (auto it = params_begin; it != params_end; ++it) {
417 auto decl_arg = *it;
418
419 TypePtr type;
420 std::optional<int32_t> N = std::nullopt;
421 if (!decl_arg.type().present()) {
422 // If this param doesn't have a type, default to "tensor"
423 type = TensorType::getInferred();
424 } else {
425 // BroadcastList list can only appear at the argument level
426 Expr type_expr = decl_arg.type().get();
427 if (auto maybe_broad_list = parseBroadcastList(type_expr)) {
428 type = maybe_broad_list->first;
429 N = maybe_broad_list->second;
430 } else {
431 type = parseTypeFromExpr(decl_arg.type().get());
432 }
433 }
434 std::optional<IValue> default_value = std::nullopt;
435 if (decl_arg.defaultValue().present()) {
436 default_value = *defaults_it++;
437 }
438 auto arg = Argument(
439 decl_arg.ident().name(),
440 type,
441 N,
442 default_value,
443 decl_arg.kwarg_only(),
444 /*alias_info=*/std::nullopt);
445 retval.push_back(arg);
446 }
447 return retval;
448 }
449
parseReturnFromDecl(const Decl & decl)450 std::vector<Argument> ScriptTypeParser::parseReturnFromDecl(const Decl& decl) {
451 // we represent no annoation on a return type as having no values in the
452 // schema's return() list
453 // in emitReturn we take the actual return value to be the value of the
454 // return statement if no one was provided here
455 if (!decl.return_type().present())
456 return {};
457
458 if (parseBroadcastList(decl.return_type().get()))
459 throw ErrorReport(decl.return_type().range())
460 << "Broadcastable lists cannot appear as a return type";
461
462 TypePtr parsed_type;
463 Expr type_expr = decl.return_type().get();
464 parsed_type = parseTypeFromExpr(type_expr);
465 return {Argument(
466 "",
467 parsed_type,
468 /*N =*/std::nullopt,
469 /*default_value =*/std::nullopt,
470 /*kwarg_only =*/false)};
471 }
parseSchemaFromDef(const Def & def,bool skip_self)472 FunctionSchema ScriptTypeParser::parseSchemaFromDef(
473 const Def& def,
474 bool skip_self) {
475 const auto name = def.name().name();
476 std::vector<Argument> args = parseArgsFromDecl(def.decl(), skip_self);
477 std::vector<Argument> returns = parseReturnFromDecl(def.decl());
478 return FunctionSchema(
479 name, "", std::move(args), std::move(returns), false, false);
480 }
481
parseClassConstant(const Assign & assign)482 c10::IValue ScriptTypeParser::parseClassConstant(const Assign& assign) {
483 if (assign.lhs().kind() != TK_VAR) {
484 throw ErrorReport(assign.range())
485 << "Expected to a variable for class constant";
486 }
487 if (!assign.type().present()) {
488 throw ErrorReport(assign.range())
489 << "Expected a type to present for class constant";
490 }
491 const auto final_type = assign.type().get();
492 auto expr = assign.rhs().get();
493 if (final_type.kind() != TK_SUBSCRIPT) {
494 throw ErrorReport(assign.range())
495 << "Expected subscripted type for class constant";
496 }
497 auto subscript = Subscript(final_type);
498 auto value_name = parseBaseTypeName(subscript.value());
499 if (!value_name) {
500 throw ErrorReport(subscript.value().range())
501 << "Subscripted type must be a type identifier";
502 }
503 if (*value_name != "Final") {
504 throw ErrorReport(subscript.range())
505 << "Base type must be Final for class constant";
506 }
507 if (subscript.subscript_exprs().size() != 1) {
508 throw ErrorReport(subscript)
509 << " expected exactly one element type but found "
510 << subscript.subscript_exprs().size();
511 }
512 auto type = *subscript.subscript_exprs().begin();
513 auto default_val = evaluateDefaults(expr.range(), {type}, {expr});
514 return *default_val.begin();
515 }
516
517 } // namespace torch::jit
518