1 #include <torch/csrc/jit/frontend/schema_type_parser.h>
2
3 #include <ATen/core/alias_info.h>
4 #include <ATen/core/jit_type.h>
5 #include <ATen/core/symbol.h>
6 #include <ATen/core/type_factory.h>
7 #include <torch/csrc/jit/frontend/lexer.h>
8 #include <torch/csrc/jit/frontend/parse_string_literal.h>
9 #include <torch/custom_class.h>
10 #include <string>
11
12 using c10::AliasInfo;
13 using c10::AwaitType;
14 using c10::BoolType;
15 using c10::CapsuleType;
16 using c10::ComplexType;
17 using c10::DeviceObjType;
18 using c10::DictType;
19 using c10::FloatType;
20 using c10::FutureType;
21 using c10::GeneratorType;
22 using c10::IntType;
23 using c10::LayoutType;
24 using c10::ListType;
25 using c10::MemoryFormatType;
26 using c10::NoneType;
27 using c10::NumberType;
28 using c10::QSchemeType;
29 using c10::QuantizerType;
30 using c10::RRefType;
31 using c10::ScalarTypeType;
32 using c10::StorageType;
33 using c10::StreamObjType;
34 using c10::StringType;
35 using c10::Symbol;
36 using c10::SymIntType;
37 using c10::TensorType;
38 using c10::TupleType;
39 using c10::UnionType;
40 using c10::VarType;
41
42 namespace torch::jit {
43
parseBaseType()44 TypePtr SchemaTypeParser::parseBaseType() {
45 static std::unordered_map<std::string, TypePtr> type_map = {
46 {"Generator", c10::TypeFactory::get<GeneratorType>()},
47 {"Dimname", c10::TypeFactory::get<StringType>()},
48 {"ScalarType", c10::TypeFactory::get<ScalarTypeType>()},
49 {"Layout", c10::TypeFactory::get<LayoutType>()},
50 {"MemoryFormat", c10::TypeFactory::get<MemoryFormatType>()},
51 {"Storage", c10::TypeFactory::get<StorageType>()},
52 {"QScheme", c10::TypeFactory::get<QSchemeType>()},
53 {"Quantizer", c10::TypeFactory::get<QuantizerType>()},
54 {"ConstQuantizerPtr",
55 c10::TypeFactory::get<IntType>()}, // TODO This type should be removed
56 // from the schema parser, it should
57 // use the custom class mechanism
58 // instead. @jerryzh
59 {"Device", c10::TypeFactory::get<DeviceObjType>()},
60 {"DeviceIndex", c10::TypeFactory::get<IntType>()},
61 {"Stream", c10::TypeFactory::get<StreamObjType>()},
62 {"Scalar", c10::TypeFactory::get<NumberType>()},
63 {"str", c10::TypeFactory::get<StringType>()},
64 {"float", c10::TypeFactory::get<FloatType>()},
65 {"complex", c10::TypeFactory::get<ComplexType>()},
66 {"int", c10::TypeFactory::get<IntType>()},
67 {"SymInt", c10::TypeFactory::get<SymIntType>()},
68 {"bool", c10::TypeFactory::get<BoolType>()},
69 {"None", c10::TypeFactory::get<NoneType>()},
70 {"NoneType", c10::TypeFactory::get<NoneType>()},
71 {"Capsule", c10::TypeFactory::get<CapsuleType>()},
72 {"Any", c10::TypeFactory::get<c10::AnyType>()},
73 {"AnyClassType", c10::TypeFactory::get<c10::AnyClassType>()},
74 {"AnyEnumType", c10::TypeFactory::get<c10::AnyEnumType>()},
75 };
76 auto tok = L.cur();
77 if (!L.nextIf(TK_NONE) && !L.nextIf(TK_NONE_TYPE)) {
78 L.expect(TK_IDENT);
79 }
80 std::string text = tok.text();
81
82 auto it = type_map.find(text);
83 if (it == type_map.end()) {
84 if (allow_typevars_ && !text.empty() && islower(text[0])) {
85 // lower case identifiers that are not otherwise valid types
86 // are treated as type variables
87 return c10::TypeFactory::createNamed<VarType>(text);
88 }
89 if (text == "double") {
90 throw(
91 ErrorReport(tok.range)
92 << "Use `float` instead of `double` in an operator's schema string. "
93 "`float` in schema corresponds to the double type in C++");
94 }
95 if (text == "int64_t") {
96 throw(
97 ErrorReport(tok.range)
98 << "Use `SymInt` or `int` instead of `int64_t` in an operator's schema string. "
99 "`SymInt` corresponds to c10::SymInt in C++ while `int` in schema corresponds "
100 "to the int64_t type in C++.");
101 }
102 throw(
103 ErrorReport(tok.range)
104 << "unknown type specifier. Common valid schema types include "
105 "Tensor, SymInt, int, float, bool, Scalar; "
106 "for a full list, please see "
107 "https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func ");
108 }
109 return it->second;
110 }
111
112 // Examples:
113 // Tensor(a) // Tensor is in set a
114 // Tensor(a!) // it is also written to
115 // Tensor! // shorthand for Tensor(fresh_identifier!)
116 // Tensor(a! -> a|b) // Tensor is in set a, written to,
117 // and after the write is in set a AND b.
// Parses an optional alias annotation following a type (see examples above).
// Returns std::nullopt when the current token starts neither a parenthesized
// annotation nor a bare '!'.
std::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
  AliasInfo alias_info;
  if (L.nextIf('(')) {
    // optional 'alias set annotation': a '|'-separated list of set names
    // (or '*' for the wildcard set) describing membership before the op runs.
    parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
      if (L.nextIf('*')) {
        alias_info.addBeforeSet(AliasInfo::wildcardSet());

        // If we found a wildcard, ignore all subsequent annotations
      } else if (!alias_info.isWildcardBefore()) {
        alias_info.addBeforeSet(
            Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
      }
    });
    // A trailing '!' inside the parens marks the value as written to.
    if (L.nextIf('!')) {
      alias_info.setIsWrite(true);
    }
    if (L.nextIf(TK_ARROW)) {
      // optional 'alias set annotation': set membership after the write.
      parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
        if (L.nextIf('*')) {
          alias_info.addAfterSet(AliasInfo::wildcardSet());

          // If we found a wildcard, ignore all subsequent annotations
        } else if (!alias_info.isWildcardAfter()) {
          alias_info.addAfterSet(
              Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
        }
      });
    } else {
      // We didn't encounter an ->, so assume the "after set" is identical
      // to the "before set"
      AT_ASSERT(alias_info.afterSets().empty());
      for (const auto& set : alias_info.beforeSets()) {
        alias_info.addAfterSet(set);
      }
    }
    L.expect(')');
  } else if (L.nextIf('!')) {
    // Bare '!' shorthand (e.g. "Tensor!"): invent a fresh alias set
    // "alias::$<n>" and mark it written.
    alias_info.addBeforeSet(
        Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
    alias_info.setIsWrite(true);
  } else {
    return std::nullopt;
  }

  return alias_info;
}
166
parseTensorDType(const std::string & dtype)167 std::optional<at::ScalarType> SchemaTypeParser::parseTensorDType(
168 const std::string& dtype) {
169 #define DEFINE_SCALAR_TYPE(_1, n) {#n, at::ScalarType::n},
170
171 static std::unordered_map<std::string, at::ScalarType> type_map = {
172 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
173
174 #undef DEFINE_SCALAR_TYPE
175 auto type = type_map.find(dtype);
176 if (type != type_map.end()) {
177 return type->second;
178 }
179 return std::nullopt;
180 }
181
tryToParseDeviceType()182 std::optional<c10::Device> SchemaTypeParser::tryToParseDeviceType() {
183 L.expect('=');
184 const std::string& dev = L.expect(TK_IDENT).text();
185
186 if (dev == "cpu") {
187 return c10::Device(at::kCPU);
188 }
189
190 if (dev == "cuda" || dev == "hpu") {
191 c10::DeviceIndex device_idx = -1;
192 if (L.cur().kind == ':') {
193 L.expect(':');
194 const std::string& num = L.expect(TK_NUMBER).text();
195 try {
196 device_idx = static_cast<c10::DeviceIndex>(std::stoi(num));
197 } catch (const std::invalid_argument& e) {
198 throw(
199 ErrorReport(L.cur())
200 << "Device index cannot be converted to integer");
201 } catch (const std::out_of_range& e) {
202 throw(ErrorReport(L.cur()) << "Device index is too long");
203 }
204 }
205 if (dev == "cuda") {
206 return c10::Device(at::kCUDA, device_idx);
207 } else {
208 return c10::Device(at::kHPU, device_idx);
209 }
210 }
211
212 throw(ErrorReport(L.cur()) << "cannot parse device type '" << dev << "'\n");
213 }
214
tryToParseRequiresGrad()215 std::optional<bool> SchemaTypeParser::tryToParseRequiresGrad() {
216 L.expect('=');
217 const std::string& num = L.expect(TK_NUMBER).text();
218 try {
219 return (bool)std::stoi(num);
220 } catch (const std::invalid_argument& e) {
221 throw(
222 ErrorReport(L.cur())
223 << "Field requires_grad cannot be converted to integer");
224 } catch (const std::out_of_range& e) {
225 throw(ErrorReport(L.cur()) << "Field requires_grad is too long");
226 }
227 }
228
// Parses a tensor type refined with a dtype and optional shape/stride/device/
// requires_grad information. The leading identifier must be a valid dtype
// (the caller has already checked this via parseTensorDType; asserted below).
TypePtr SchemaTypeParser::parseRefinedTensor() {
  auto maybe_dtype = parseTensorDType(L.expect(TK_IDENT).text());
  AT_ASSERT(maybe_dtype);
  at::ScalarType dtype = *maybe_dtype;
  TypePtr ptr;
  L.expect('(');
  TypePtr tensor_type;
  std::optional<c10::Device> device;
  std::optional<bool> requires_grad;
  // Parse a type with either no ranks, known ranks with sizes, ranks with
  // unknown sizes, a mix of ranks with known and unknown sizes, or ranks with
  // known sizes and strides. The type might also have requires_grad and/or
  // device option. Examples of types we're handling here:
  //   Long(10, 8, 6, strides=[48, 6, 1], requires_grad=0, device=cuda:1)
  //   Float(10, *, 20, device=cuda:1)
  //   Float(requires_grad=1)
  std::vector<std::optional<int64_t>> dims; // nullopt == unsized rank ('*')
  bool seen_strides = false;
  std::vector<int64_t> strides;
  parseList(TK_NOTHING, ',', ')', [&] {
    // Extra handling for options like 'device' and 'requires_grad'.
    // "SS" is excluded because it introduces a shape symbol, handled below.
    if (L.cur().kind == TK_IDENT && L.cur().text() != "SS") {
      const std::string& field = L.expect(TK_IDENT).text();
      if (field == "device") {
        auto parsed_device = tryToParseDeviceType();
        if (parsed_device.has_value()) {
          if (device.has_value()) {
            throw(ErrorReport(L.cur()) << "'device' is specified twice");
          }
          device = parsed_device;
        }
        return;
      }
      if (field == "requires_grad") {
        auto parsed_requires_grad = tryToParseRequiresGrad();
        if (parsed_requires_grad.has_value()) {
          if (requires_grad.has_value()) {
            throw(ErrorReport(L.cur()) << "'requires_grad' is specified twice");
          }
          requires_grad = parsed_requires_grad;
        }
        return;
      }
      if (field == "strides") {
        seen_strides = true;
        L.expect('=');
        // strides=[s0, s1, ...] — one entry per dimension (validated after
        // the whole list is parsed).
        parseList('[', ',', ']', [&] {
          const std::string& num = L.expect(TK_NUMBER).text();
          try {
            auto stride = std::stoll(num);
            strides.push_back(stride);
          } catch (const std::invalid_argument& e) {
            throw(
                ErrorReport(L.cur())
                << "The stride value cannot be converted to int");
          } catch (const std::out_of_range& e) {
            throw(ErrorReport(L.cur()) << "The stride is too big");
          }
        });
        return;
      }
      throw(ErrorReport(L.cur()) << "Unexpected specifier '" << field << "'");
    }
    // A plain dimension entry after an option is an ordering error.
    if (device.has_value() || requires_grad.has_value()) {
      throw(
          ErrorReport(L.cur())
          << "'device' and 'requires_grad' should come after dimensions in the type specification");
    }

    // Parsing ranks, supports mix of sized and unsized ranks, or, just strided
    // ranks
    if (L.cur().kind == '*') {
      // '*' == rank present but size unknown.
      dims.emplace_back(std::nullopt);
      L.next();
      if (L.cur().kind == ':') {
        throw(
            ErrorReport(L.cur()) << "Strides for unsized ranks not supported");
      }
      return;
    }
    bool shape_symbol = false;
    // "SS(-N)" introduces a shape-symbol dimension; the parser requires the
    // literal '-' before the number.
    if (L.cur().kind == TK_IDENT && L.cur().text() == "SS") {
      L.next();
      L.expect('(');
      L.expect('-');
      shape_symbol = true;
    }
    const std::string& num = L.expect(TK_NUMBER).text();
    int64_t dim = 0;
    try {
      dim = std::stoll(num);
    } catch (const std::invalid_argument& e) {
      throw(ErrorReport(L.cur()) << "The number can't be converted to int");
    } catch (const std::out_of_range& e) {
      throw(ErrorReport(L.cur()) << "Number is too big");
    }
    if (shape_symbol) {
      L.expect(')');
      // Shape symbols are stored as negative dims — presumably decoded
      // downstream by TensorType; verify against consumers if touching this.
      dim = -dim;
    }
    dims.emplace_back(dim);
  });
  if (seen_strides) {
    at::IntArrayRef strides_ref(strides);
    if (strides.size() != dims.size()) {
      // note: mixing unsized ranks and ranks with strides will always trigger
      // this
      throw(
          ErrorReport(L.cur())
          << "Strides info is specified for some but not for all dimensions");
    }
    ptr = at::TensorType::create(
        dtype,
        device,
        c10::VaryingShape<int64_t>(dims),
        c10::VaryingShape<int64_t>(strides),
        requires_grad);
  } else {
    // No strides given: sizes (possibly partially unknown) with a stride
    // shape of known rank but unknown values.
    ptr = at::TensorType::create(
        dtype,
        device,
        c10::VaryingShape<int64_t>(dims),
        c10::VaryingShape<int64_t>(dims.size()),
        requires_grad);
  }
  return ptr;
}
356
parseType()357 std::pair<TypePtr, std::optional<AliasInfo>> SchemaTypeParser::parseType() {
358 auto r = parseFakeAndRealType();
359 return std::make_pair(std::move(std::get<0>(r)), std::move(std::get<2>(r)));
360 }
361
// Parses a full type expression, returning both the "fake" type (what the
// JIT sees, e.g. int for ScalarType/Layout/MemoryFormat/SymInt) and the
// "real" type (the declared schema type), plus any alias annotation.
// Handles tuples, Future/Await/RRef, Tensor, Dict, Union, refined tensor
// types, custom classes under __torch__.torch.classes, base types, and
// trailing []/? list/optional modifiers.
std::tuple</*fake*/ TypePtr, /*real*/ TypePtr, std::optional<AliasInfo>>
SchemaTypeParser::parseFakeAndRealType() {
  TypePtr fake_value;
  TypePtr real_value;
  std::optional<AliasInfo> alias_info;
  // Tuple type
  if (L.cur().kind == '(') {
    std::vector<TypePtr> types;
    parseList('(', ',', ')', [&] {
      auto r = parseType();
      types.push_back(std::move(r.first));
      // NOTE(review): alias_info is still nullopt at this point (nothing
      // assigns it before this branch), so this condition never holds and
      // element alias annotations are dropped for tuples — confirm whether
      // that is intentional.
      if (alias_info && r.second) {
        alias_info->addContainedType(std::move(*r.second));
      }
    });
    fake_value = real_value =
        c10::TypeFactory::create<TupleType>(std::move(types));
  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
    L.next(); // Future
    L.expect('(');
    auto p = parseType();
    auto subtype = std::move(p.first);
    // subalias is currently unused: the element's alias annotation is
    // discarded for Future.
    auto subalias = std::move(p.second);
    L.expect(')');
    fake_value = real_value = c10::TypeFactory::create<FutureType>(subtype);
  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Await") {
    L.next(); // Await
    L.expect('(');
    auto p = parseType();
    auto subtype = std::move(p.first);
    // Unused, as above — element alias annotation discarded.
    auto subalias = std::move(p.second);
    L.expect(')');
    fake_value = real_value = c10::TypeFactory::create<AwaitType>(subtype);
  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
    L.next(); // RRef
    L.expect('(');
    auto p = parseType();
    auto subtype = std::move(p.first);
    // Unused, as above — element alias annotation discarded.
    auto subalias = std::move(p.second);
    L.expect(')');
    fake_value = real_value = c10::TypeFactory::create<RRefType>(subtype);
  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
    L.next();
    fake_value = real_value = c10::TypeFactory::get<TensorType>();
    alias_info = parseAliasAnnotation();
  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
    L.next();
    L.expect('(');
    auto key_type = parseType().first;
    L.expect(',');
    auto value_type = parseType().first;
    L.expect(')');
    alias_info = parseAliasAnnotation();
    fake_value = real_value =
        c10::TypeFactory::create<DictType>(key_type, value_type);
  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
    L.next();
    L.expect('(');
    // Union requires at least one member; further members are comma-separated.
    std::vector<TypePtr> types;
    types.emplace_back(parseType().first);
    while (L.cur().kind != ')') {
      L.expect(',');
      types.emplace_back(parseType().first);
    }
    L.expect(')');
    alias_info = parseAliasAnnotation();
    fake_value = real_value =
        c10::TypeFactory::create<c10::UnionType>(std::move(types));
  } else if (
      complete_tensor_types && L.cur().kind == TK_IDENT &&
      parseTensorDType(L.cur().text())) {
    // A dtype identifier (e.g. "Float") starts a refined tensor type, but
    // only when complete tensor types are enabled.
    fake_value = real_value = parseRefinedTensor();
    alias_info = parseAliasAnnotation();
  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") {
    // Custom class reference: __torch__.torch.classes.<ns>.<class>
    L.next();
    L.expect('.');
    auto torch_tok = L.expect(TK_IDENT);
    if (torch_tok.text() != "torch") {
      throw(
          ErrorReport(torch_tok.range)
          << "Expected classes namespace but got " << torch_tok.text());
    }
    L.expect('.');
    auto classes_tok = L.expect(TK_IDENT);
    if (classes_tok.text() != "classes") {
      throw(
          ErrorReport(classes_tok.range)
          << "Expected classes namespace but got " << classes_tok.text());
    }
    L.expect('.');
    auto ns_tok = L.expect(TK_IDENT);
    L.expect('.');
    auto class_tok = L.expect(TK_IDENT);
    fake_value = real_value = getCustomClass(
        std::string("__torch__.torch.classes.") + ns_tok.text() + "." +
        class_tok.text());
    if (!fake_value) {
      throw(
          ErrorReport(class_tok.range) << "Unknown custom class type "
                                       << ns_tok.text() + "." + class_tok.text()
                                       << ". Please ensure it is registered.");
    }
  } else {
    real_value = parseBaseType();
    // These schema-level types are represented as plain ints at runtime,
    // so the fake type diverges from the real one.
    if (real_value->kind() == ScalarTypeType::Kind ||
        real_value->kind() == MemoryFormatType::Kind ||
        real_value->kind() == LayoutType::Kind ||
        real_value->kind() == SymIntType::Kind) {
      fake_value = c10::TypeFactory::get<IntType>();
    } else {
      fake_value = real_value;
    }
    alias_info = parseAliasAnnotation();
  }
  // Apply trailing modifiers: "[]" wraps in a list (nesting the current
  // alias info inside the list's annotation), "?" wraps in an optional.
  while (true) {
    if (L.cur().kind == '[' && L.lookahead().kind == ']') {
      L.next(); // [
      L.next(); // ]
      fake_value = c10::TypeFactory::create<ListType>(fake_value);
      real_value = c10::TypeFactory::create<ListType>(real_value);
      auto container = parseAliasAnnotation();
      if (alias_info) {
        if (!container) {
          // Synthesize a container annotation so the element's write flag
          // propagates outward.
          container = std::optional<AliasInfo>(AliasInfo());
          container->setIsWrite(alias_info->isWrite());
        }
        container->addContainedType(std::move(*alias_info));
      }
      alias_info = std::move(container);
    } else if (L.nextIf('?')) {
      fake_value = c10::OptionalType::get(fake_value);
      real_value = c10::OptionalType::get(real_value);
    } else {
      break;
    }
  }
  return std::make_tuple(
      std::move(fake_value), std::move(real_value), std::move(alias_info));
}
501
parseList(int begin,int sep,int end,c10::function_ref<void ()> callback)502 void SchemaTypeParser::parseList(
503 int begin,
504 int sep,
505 int end,
506 c10::function_ref<void()> callback) {
507 auto r = L.cur().range;
508 if (begin != TK_NOTHING)
509 L.expect(begin);
510 if (L.cur().kind != end) {
511 do {
512 callback();
513 } while (L.nextIf(sep));
514 }
515 if (end != TK_NOTHING)
516 L.expect(end);
517 }
518
519 } // namespace torch::jit
520