xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/schema_type_parser.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/schema_type_parser.h>
2 
3 #include <ATen/core/alias_info.h>
4 #include <ATen/core/jit_type.h>
5 #include <ATen/core/symbol.h>
6 #include <ATen/core/type_factory.h>
7 #include <torch/csrc/jit/frontend/lexer.h>
8 #include <torch/csrc/jit/frontend/parse_string_literal.h>
9 #include <torch/custom_class.h>
10 #include <string>
11 
12 using c10::AliasInfo;
13 using c10::AwaitType;
14 using c10::BoolType;
15 using c10::CapsuleType;
16 using c10::ComplexType;
17 using c10::DeviceObjType;
18 using c10::DictType;
19 using c10::FloatType;
20 using c10::FutureType;
21 using c10::GeneratorType;
22 using c10::IntType;
23 using c10::LayoutType;
24 using c10::ListType;
25 using c10::MemoryFormatType;
26 using c10::NoneType;
27 using c10::NumberType;
28 using c10::QSchemeType;
29 using c10::QuantizerType;
30 using c10::RRefType;
31 using c10::ScalarTypeType;
32 using c10::StorageType;
33 using c10::StreamObjType;
34 using c10::StringType;
35 using c10::Symbol;
36 using c10::SymIntType;
37 using c10::TensorType;
38 using c10::TupleType;
39 using c10::UnionType;
40 using c10::VarType;
41 
42 namespace torch::jit {
43 
parseBaseType()44 TypePtr SchemaTypeParser::parseBaseType() {
45   static std::unordered_map<std::string, TypePtr> type_map = {
46       {"Generator", c10::TypeFactory::get<GeneratorType>()},
47       {"Dimname", c10::TypeFactory::get<StringType>()},
48       {"ScalarType", c10::TypeFactory::get<ScalarTypeType>()},
49       {"Layout", c10::TypeFactory::get<LayoutType>()},
50       {"MemoryFormat", c10::TypeFactory::get<MemoryFormatType>()},
51       {"Storage", c10::TypeFactory::get<StorageType>()},
52       {"QScheme", c10::TypeFactory::get<QSchemeType>()},
53       {"Quantizer", c10::TypeFactory::get<QuantizerType>()},
54       {"ConstQuantizerPtr",
55        c10::TypeFactory::get<IntType>()}, // TODO This type should be removed
56                                           // from the schema parser, it should
57                                           // use the custom class mechanism
58                                           // instead. @jerryzh
59       {"Device", c10::TypeFactory::get<DeviceObjType>()},
60       {"DeviceIndex", c10::TypeFactory::get<IntType>()},
61       {"Stream", c10::TypeFactory::get<StreamObjType>()},
62       {"Scalar", c10::TypeFactory::get<NumberType>()},
63       {"str", c10::TypeFactory::get<StringType>()},
64       {"float", c10::TypeFactory::get<FloatType>()},
65       {"complex", c10::TypeFactory::get<ComplexType>()},
66       {"int", c10::TypeFactory::get<IntType>()},
67       {"SymInt", c10::TypeFactory::get<SymIntType>()},
68       {"bool", c10::TypeFactory::get<BoolType>()},
69       {"None", c10::TypeFactory::get<NoneType>()},
70       {"NoneType", c10::TypeFactory::get<NoneType>()},
71       {"Capsule", c10::TypeFactory::get<CapsuleType>()},
72       {"Any", c10::TypeFactory::get<c10::AnyType>()},
73       {"AnyClassType", c10::TypeFactory::get<c10::AnyClassType>()},
74       {"AnyEnumType", c10::TypeFactory::get<c10::AnyEnumType>()},
75   };
76   auto tok = L.cur();
77   if (!L.nextIf(TK_NONE) && !L.nextIf(TK_NONE_TYPE)) {
78     L.expect(TK_IDENT);
79   }
80   std::string text = tok.text();
81 
82   auto it = type_map.find(text);
83   if (it == type_map.end()) {
84     if (allow_typevars_ && !text.empty() && islower(text[0])) {
85       // lower case identifiers that are not otherwise valid types
86       // are treated as type variables
87       return c10::TypeFactory::createNamed<VarType>(text);
88     }
89     if (text == "double") {
90       throw(
91           ErrorReport(tok.range)
92           << "Use `float` instead of `double` in an operator's schema string. "
93              "`float` in schema corresponds to the double type in C++");
94     }
95     if (text == "int64_t") {
96       throw(
97           ErrorReport(tok.range)
98           << "Use `SymInt` or `int` instead of `int64_t` in an operator's schema string. "
99              "`SymInt` corresponds to c10::SymInt in C++ while `int` in schema corresponds "
100              "to the int64_t type in C++.");
101     }
102     throw(
103         ErrorReport(tok.range)
104         << "unknown type specifier. Common valid schema types include "
105            "Tensor, SymInt, int, float, bool, Scalar; "
106            "for a full list, please see "
107            "https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func ");
108   }
109   return it->second;
110 }
111 
112 // Examples:
113 // Tensor(a) // Tensor is in set a
114 // Tensor(a!) // it is also written to
115 // Tensor!  // shorthand for Tensor(fresh_identifier!)
116 // Tensor(a! -> a|b) // Tensor is in set a, written to,
117 //                      and after the write is in set a AND b.
parseAliasAnnotation()118 std::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
119   AliasInfo alias_info;
120   if (L.nextIf('(')) {
121     // optional 'alias set annotation'
122     parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
123       if (L.nextIf('*')) {
124         alias_info.addBeforeSet(AliasInfo::wildcardSet());
125 
126         // If we found a wildcard, ignore all subsequent annotations
127       } else if (!alias_info.isWildcardBefore()) {
128         alias_info.addBeforeSet(
129             Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
130       }
131     });
132     if (L.nextIf('!')) {
133       alias_info.setIsWrite(true);
134     }
135     if (L.nextIf(TK_ARROW)) {
136       // optional 'alias set annotation'
137       parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
138         if (L.nextIf('*')) {
139           alias_info.addAfterSet(AliasInfo::wildcardSet());
140 
141           // If we found a wildcard, ignore all subsequent annotations
142         } else if (!alias_info.isWildcardAfter()) {
143           alias_info.addAfterSet(
144               Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
145         }
146       });
147     } else {
148       // We didn't encounter an ->, so assume the "after set" is identical
149       // to the "before set"
150       AT_ASSERT(alias_info.afterSets().empty());
151       for (const auto& set : alias_info.beforeSets()) {
152         alias_info.addAfterSet(set);
153       }
154     }
155     L.expect(')');
156   } else if (L.nextIf('!')) {
157     alias_info.addBeforeSet(
158         Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
159     alias_info.setIsWrite(true);
160   } else {
161     return std::nullopt;
162   }
163 
164   return alias_info;
165 }
166 
parseTensorDType(const std::string & dtype)167 std::optional<at::ScalarType> SchemaTypeParser::parseTensorDType(
168     const std::string& dtype) {
169 #define DEFINE_SCALAR_TYPE(_1, n) {#n, at::ScalarType::n},
170 
171   static std::unordered_map<std::string, at::ScalarType> type_map = {
172       AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
173 
174 #undef DEFINE_SCALAR_TYPE
175   auto type = type_map.find(dtype);
176   if (type != type_map.end()) {
177     return type->second;
178   }
179   return std::nullopt;
180 }
181 
tryToParseDeviceType()182 std::optional<c10::Device> SchemaTypeParser::tryToParseDeviceType() {
183   L.expect('=');
184   const std::string& dev = L.expect(TK_IDENT).text();
185 
186   if (dev == "cpu") {
187     return c10::Device(at::kCPU);
188   }
189 
190   if (dev == "cuda" || dev == "hpu") {
191     c10::DeviceIndex device_idx = -1;
192     if (L.cur().kind == ':') {
193       L.expect(':');
194       const std::string& num = L.expect(TK_NUMBER).text();
195       try {
196         device_idx = static_cast<c10::DeviceIndex>(std::stoi(num));
197       } catch (const std::invalid_argument& e) {
198         throw(
199             ErrorReport(L.cur())
200             << "Device index cannot be converted to integer");
201       } catch (const std::out_of_range& e) {
202         throw(ErrorReport(L.cur()) << "Device index is too long");
203       }
204     }
205     if (dev == "cuda") {
206       return c10::Device(at::kCUDA, device_idx);
207     } else {
208       return c10::Device(at::kHPU, device_idx);
209     }
210   }
211 
212   throw(ErrorReport(L.cur()) << "cannot parse device type '" << dev << "'\n");
213 }
214 
tryToParseRequiresGrad()215 std::optional<bool> SchemaTypeParser::tryToParseRequiresGrad() {
216   L.expect('=');
217   const std::string& num = L.expect(TK_NUMBER).text();
218   try {
219     return (bool)std::stoi(num);
220   } catch (const std::invalid_argument& e) {
221     throw(
222         ErrorReport(L.cur())
223         << "Field requires_grad cannot be converted to integer");
224   } catch (const std::out_of_range& e) {
225     throw(ErrorReport(L.cur()) << "Field requires_grad is too long");
226   }
227 }
228 
parseRefinedTensor()229 TypePtr SchemaTypeParser::parseRefinedTensor() {
230   auto maybe_dtype = parseTensorDType(L.expect(TK_IDENT).text());
231   AT_ASSERT(maybe_dtype);
232   at::ScalarType dtype = *maybe_dtype;
233   TypePtr ptr;
234   L.expect('(');
235   TypePtr tensor_type;
236   std::optional<c10::Device> device;
237   std::optional<bool> requires_grad;
238   // Parse a type with either no ranks, known ranks with sizes, ranks with
239   // unknown sizes, a mix of ranks with known and unknown sizes, or ranks with
240   // known sizes and strides. The type might also have requires_grad and/or
241   // device option. Examples of types we're handling here:
242   //   Long(10, 8, 6, strides=[48, 6, 1], requires_grad=0, device=cuda:1)
243   //   Float(10, *, 20, device=cuda:1)
244   //   Float(requires_grad=1)
245   std::vector<std::optional<int64_t>> dims;
246   bool seen_strides = false;
247   std::vector<int64_t> strides;
248   parseList(TK_NOTHING, ',', ')', [&] {
249     // Extra handling for options like 'device' and 'requires_grad'
250     if (L.cur().kind == TK_IDENT && L.cur().text() != "SS") {
251       const std::string& field = L.expect(TK_IDENT).text();
252       if (field == "device") {
253         auto parsed_device = tryToParseDeviceType();
254         if (parsed_device.has_value()) {
255           if (device.has_value()) {
256             throw(ErrorReport(L.cur()) << "'device' is specified twice");
257           }
258           device = parsed_device;
259         }
260         return;
261       }
262       if (field == "requires_grad") {
263         auto parsed_requires_grad = tryToParseRequiresGrad();
264         if (parsed_requires_grad.has_value()) {
265           if (requires_grad.has_value()) {
266             throw(ErrorReport(L.cur()) << "'requires_grad' is specified twice");
267           }
268           requires_grad = parsed_requires_grad;
269         }
270         return;
271       }
272       if (field == "strides") {
273         seen_strides = true;
274         L.expect('=');
275         parseList('[', ',', ']', [&] {
276           const std::string& num = L.expect(TK_NUMBER).text();
277           try {
278             auto stride = std::stoll(num);
279             strides.push_back(stride);
280           } catch (const std::invalid_argument& e) {
281             throw(
282                 ErrorReport(L.cur())
283                 << "The stride value cannot be converted to int");
284           } catch (const std::out_of_range& e) {
285             throw(ErrorReport(L.cur()) << "The stride is too big");
286           }
287         });
288         return;
289       }
290       throw(ErrorReport(L.cur()) << "Unexpected specifier '" << field << "'");
291     }
292     if (device.has_value() || requires_grad.has_value()) {
293       throw(
294           ErrorReport(L.cur())
295           << "'device' and 'requires_grad' should come after dimensions in the type specification");
296     }
297 
298     // Parsing ranks, supports mix of sized and unsized ranks, or, just strided
299     // ranks
300     if (L.cur().kind == '*') {
301       dims.emplace_back(std::nullopt);
302       L.next();
303       if (L.cur().kind == ':') {
304         throw(
305             ErrorReport(L.cur()) << "Strides for unsized ranks not supported");
306       }
307       return;
308     }
309     bool shape_symbol = false;
310     if (L.cur().kind == TK_IDENT && L.cur().text() == "SS") {
311       L.next();
312       L.expect('(');
313       L.expect('-');
314       shape_symbol = true;
315     }
316     const std::string& num = L.expect(TK_NUMBER).text();
317     int64_t dim = 0;
318     try {
319       dim = std::stoll(num);
320     } catch (const std::invalid_argument& e) {
321       throw(ErrorReport(L.cur()) << "The number can't be converted to int");
322     } catch (const std::out_of_range& e) {
323       throw(ErrorReport(L.cur()) << "Number is too big");
324     }
325     if (shape_symbol) {
326       L.expect(')');
327       dim = -dim;
328     }
329     dims.emplace_back(dim);
330   });
331   if (seen_strides) {
332     at::IntArrayRef strides_ref(strides);
333     if (strides.size() != dims.size()) {
334       // note: mixing unsized ranks and ranks with strides will always trigger
335       // this
336       throw(
337           ErrorReport(L.cur())
338           << "Strides info is specified for some but not for all dimensions");
339     }
340     ptr = at::TensorType::create(
341         dtype,
342         device,
343         c10::VaryingShape<int64_t>(dims),
344         c10::VaryingShape<int64_t>(strides),
345         requires_grad);
346   } else {
347     ptr = at::TensorType::create(
348         dtype,
349         device,
350         c10::VaryingShape<int64_t>(dims),
351         c10::VaryingShape<int64_t>(dims.size()),
352         requires_grad);
353   }
354   return ptr;
355 }
356 
parseType()357 std::pair<TypePtr, std::optional<AliasInfo>> SchemaTypeParser::parseType() {
358   auto r = parseFakeAndRealType();
359   return std::make_pair(std::move(std::get<0>(r)), std::move(std::get<2>(r)));
360 }
361 
362 std::tuple</*fake*/ TypePtr, /*real*/ TypePtr, std::optional<AliasInfo>>
parseFakeAndRealType()363 SchemaTypeParser::parseFakeAndRealType() {
364   TypePtr fake_value;
365   TypePtr real_value;
366   std::optional<AliasInfo> alias_info;
367   // Tuple type
368   if (L.cur().kind == '(') {
369     std::vector<TypePtr> types;
370     parseList('(', ',', ')', [&] {
371       auto r = parseType();
372       types.push_back(std::move(r.first));
373       if (alias_info && r.second) {
374         alias_info->addContainedType(std::move(*r.second));
375       }
376     });
377     fake_value = real_value =
378         c10::TypeFactory::create<TupleType>(std::move(types));
379   } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
380     L.next(); // Future
381     L.expect('(');
382     auto p = parseType();
383     auto subtype = std::move(p.first);
384     auto subalias = std::move(p.second);
385     L.expect(')');
386     fake_value = real_value = c10::TypeFactory::create<FutureType>(subtype);
387   } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Await") {
388     L.next(); // Await
389     L.expect('(');
390     auto p = parseType();
391     auto subtype = std::move(p.first);
392     auto subalias = std::move(p.second);
393     L.expect(')');
394     fake_value = real_value = c10::TypeFactory::create<AwaitType>(subtype);
395   } else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
396     L.next(); // RRef
397     L.expect('(');
398     auto p = parseType();
399     auto subtype = std::move(p.first);
400     auto subalias = std::move(p.second);
401     L.expect(')');
402     fake_value = real_value = c10::TypeFactory::create<RRefType>(subtype);
403   } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
404     L.next();
405     fake_value = real_value = c10::TypeFactory::get<TensorType>();
406     alias_info = parseAliasAnnotation();
407   } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
408     L.next();
409     L.expect('(');
410     auto key_type = parseType().first;
411     L.expect(',');
412     auto value_type = parseType().first;
413     L.expect(')');
414     alias_info = parseAliasAnnotation();
415     fake_value = real_value =
416         c10::TypeFactory::create<DictType>(key_type, value_type);
417   } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
418     L.next();
419     L.expect('(');
420     std::vector<TypePtr> types;
421     types.emplace_back(parseType().first);
422     while (L.cur().kind != ')') {
423       L.expect(',');
424       types.emplace_back(parseType().first);
425     }
426     L.expect(')');
427     alias_info = parseAliasAnnotation();
428     fake_value = real_value =
429         c10::TypeFactory::create<c10::UnionType>(std::move(types));
430   } else if (
431       complete_tensor_types && L.cur().kind == TK_IDENT &&
432       parseTensorDType(L.cur().text())) {
433     fake_value = real_value = parseRefinedTensor();
434     alias_info = parseAliasAnnotation();
435   } else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") {
436     L.next();
437     L.expect('.');
438     auto torch_tok = L.expect(TK_IDENT);
439     if (torch_tok.text() != "torch") {
440       throw(
441           ErrorReport(torch_tok.range)
442           << "Expected classes namespace but got " << torch_tok.text());
443     }
444     L.expect('.');
445     auto classes_tok = L.expect(TK_IDENT);
446     if (classes_tok.text() != "classes") {
447       throw(
448           ErrorReport(classes_tok.range)
449           << "Expected classes namespace but got " << classes_tok.text());
450     }
451     L.expect('.');
452     auto ns_tok = L.expect(TK_IDENT);
453     L.expect('.');
454     auto class_tok = L.expect(TK_IDENT);
455     fake_value = real_value = getCustomClass(
456         std::string("__torch__.torch.classes.") + ns_tok.text() + "." +
457         class_tok.text());
458     if (!fake_value) {
459       throw(
460           ErrorReport(class_tok.range) << "Unknown custom class type "
461                                        << ns_tok.text() + "." + class_tok.text()
462                                        << ". Please ensure it is registered.");
463     }
464   } else {
465     real_value = parseBaseType();
466     if (real_value->kind() == ScalarTypeType::Kind ||
467         real_value->kind() == MemoryFormatType::Kind ||
468         real_value->kind() == LayoutType::Kind ||
469         real_value->kind() == SymIntType::Kind) {
470       fake_value = c10::TypeFactory::get<IntType>();
471     } else {
472       fake_value = real_value;
473     }
474     alias_info = parseAliasAnnotation();
475   }
476   while (true) {
477     if (L.cur().kind == '[' && L.lookahead().kind == ']') {
478       L.next(); // [
479       L.next(); // ]
480       fake_value = c10::TypeFactory::create<ListType>(fake_value);
481       real_value = c10::TypeFactory::create<ListType>(real_value);
482       auto container = parseAliasAnnotation();
483       if (alias_info) {
484         if (!container) {
485           container = std::optional<AliasInfo>(AliasInfo());
486           container->setIsWrite(alias_info->isWrite());
487         }
488         container->addContainedType(std::move(*alias_info));
489       }
490       alias_info = std::move(container);
491     } else if (L.nextIf('?')) {
492       fake_value = c10::OptionalType::get(fake_value);
493       real_value = c10::OptionalType::get(real_value);
494     } else {
495       break;
496     }
497   }
498   return std::make_tuple(
499       std::move(fake_value), std::move(real_value), std::move(alias_info));
500 }
501 
parseList(int begin,int sep,int end,c10::function_ref<void ()> callback)502 void SchemaTypeParser::parseList(
503     int begin,
504     int sep,
505     int end,
506     c10::function_ref<void()> callback) {
507   auto r = L.cur().range;
508   if (begin != TK_NOTHING)
509     L.expect(begin);
510   if (L.cur().kind != end) {
511     do {
512       callback();
513     } while (L.nextIf(sep));
514   }
515   if (end != TK_NOTHING)
516     L.expect(end);
517 }
518 
519 } // namespace torch::jit
520