1 #include <torch/csrc/jit/ir/irparser.h>
2
3 #include <ATen/EmptyTensor.h>
4 #include <torch/csrc/jit/frontend/lexer.h>
5 #include <torch/csrc/jit/frontend/parse_string_literal.h>
6 #include <torch/csrc/jit/frontend/schema_type_parser.h>
7 #include <torch/csrc/jit/ir/ir.h>
8
9 #ifndef AT_PER_OPERATOR_HEADERS
10 #include <ATen/Functions.h>
11 #else
12 #include <ATen/ops/empty.h>
13 #include <ATen/ops/empty_strided.h>
14 #endif
15
16 #include <string>
17 #include <vector>
18
19 namespace torch::jit {
20
struct VarWithType;
struct ParsedLiteral;

// Recursive-descent parser that reads textual IR from a string and appends the
// corresponding graph inputs, nodes, nested blocks and graph outputs to an
// existing Graph. Every member (including the constructor) is private, so
// instances are only created by the parseIR() friend function below.
class IRParser {
  friend void parseIR(
      const std::string& str,
      torch::jit::Graph* graph,
      std::unordered_map<std::string, Value*>& vmap,
      bool parse_tensor_constants);
  // `vmap` is caller-owned; names defined while parsing are written into it
  // and names used as operands are looked up in it. When
  // `parse_tensor_constants` is false, tensor literals are rejected.
  IRParser(
      const std::string& str,
      torch::jit::Graph* graph,
      std::unordered_map<std::string, Value*>& vmap,
      bool parse_tensor_constants)
      : L(std::make_shared<Source>(str)),
        g(graph),
        vmap(vmap),
        type_parser(
            L,
            /*parse_complete_tensor_types*/ true,
            /*allow_type_vars*/ true),
        parse_tensor_constants_(parse_tensor_constants) {}

  // Variable references: "%name" with an optional ": Type".
  std::string parseVar();
  VarWithType parseVarWithType(bool allow_optional = false);
  // String / number / type / tensor literal used as an attribute value.
  ParsedLiteral parseScalarLiteral(Node* n);

  // Top level: "name(inputs):", statements, "return (...)".
  void parse();
  void parseGraphInputs();
  void parseReturnOperator();

  // Nested blocks attached to a node.
  void parseBlocks(Node* parentNode);
  void parseBlock(Node* parentNode);
  void parseBlockInputs(Block* b);
  void parseBlockOutputs(Block* b);

  // Statements: "<outputs> = ns::op[<attrs>](<inputs>)".
  void parseOperatorsList(Block* b);
  void parseOperator(Block* b);
  void parseOperatorOutputs(std::vector<VarWithType>* outs);
  std::string parseOperatorName();
  void parseOperatorInputs(Node* n);
  void parseAttrs(Node* n);
  void parseAttr(Node* n);

  // Generic "begin elem sep elem ... end" helper; TK_NOTHING means the
  // corresponding delimiter is absent.
  void parseList(
      int begin,
      int sep,
      int end,
      const std::function<void()>& callback);

  // Skips the bracketed type list inside an annotate(...) attribute.
  void bypassTypeAnnotationList();

  // Looks `name` up in vmap, throwing ErrorReport if it is unknown.
  Value* findValueInVMap(const std::string& name);

  // Declared before type_parser on purpose: members are initialized in
  // declaration order and type_parser is constructed from L.
  torch::jit::Lexer L;
  // Graph being populated.
  torch::jit::Graph* g = nullptr;
  // Caller-owned map from IR variable names (without '%') to Values.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  std::unordered_map<std::string, Value*>& vmap;
  SchemaTypeParser type_parser;
  // Whether tensor literals ("<Tensor>", "{N}") are accepted.
  bool parse_tensor_constants_;
  // Nodes whose attr::value is filled in at the end of parse().
  std::vector<Node*> deferred_tensor_value_initializations_;
  std::vector<Node*> deferred_empty_container_initializations_;
};
84
85 struct ParsedLiteral {
86 ParsedLiteral() = default;
87
88 AttributeKind k = AttributeKind::t;
89
90 int64_t i = 0;
91 std::string s = "";
92 double f = 0.0;
93 c10::complex<double> c = c10::complex<double>(0, 0);
94 TypePtr ty;
95 std::vector<int64_t> is;
96 std::vector<std::string> ss;
97 std::vector<double> fs;
98 std::vector<c10::complex<double>> cs;
99 std::vector<TypePtr> tys;
100 };
101
102 struct VarWithType {
103 VarWithType() = default;
104 std::string name;
105 TypePtr type;
106 };
107
parseIR(const std::string & str,torch::jit::Graph * graph,std::unordered_map<std::string,Value * > & vmap,bool parse_tensor_constants)108 void parseIR(
109 const std::string& str,
110 torch::jit::Graph* graph,
111 std::unordered_map<std::string, Value*>& vmap,
112 bool parse_tensor_constants) {
113 torch::jit::IRParser p(str, graph, vmap, parse_tensor_constants);
114 p.parse();
115 }
116
parseIR(const std::string & str,torch::jit::Graph * graph,bool parse_tensor_constants)117 void parseIR(
118 const std::string& str,
119 torch::jit::Graph* graph,
120 bool parse_tensor_constants) {
121 std::unordered_map<std::string, Value*> vmap;
122 parseIR(str, graph, vmap, parse_tensor_constants);
123 }
124
parseVarWithType(bool allow_optional)125 VarWithType IRParser::parseVarWithType(bool allow_optional) {
126 VarWithType r;
127 r.name = parseVar();
128 if (allow_optional) {
129 r.type = nullptr;
130 } else {
131 r.type = TensorType::get();
132 }
133 if (L.nextIf(':')) {
134 auto type_alias = type_parser.parseType();
135 AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
136 r.type = type_alias.first;
137 }
138 return r;
139 }
140
parseVar()141 std::string IRParser::parseVar() {
142 L.expect('%');
143 std::string name;
144 bool continue_parsing = false;
145 do {
146 if (L.cur().kind == TK_IDENT) {
147 name += L.expect(TK_IDENT).text();
148 } else {
149 name += L.expect(TK_NUMBER).text();
150 }
151 continue_parsing = false;
152 if (L.nextIf('.')) {
153 continue_parsing = true;
154 name += '.';
155 } else if (L.cur().kind == TK_NUMBER && L.cur().text()[0] == '.') {
156 continue_parsing = true;
157 }
158 } while (continue_parsing);
159 return name;
160 }
161
parseOperatorOutputs(std::vector<VarWithType> * outs)162 void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {
163 if (L.cur().kind != '%') {
164 return;
165 }
166 parseList(TK_NOTHING, ',', TK_NOTHING, [&] {
167 outs->push_back(parseVarWithType(true));
168 });
169 L.expect('=');
170 }
171
172 // Parse string or numeric literal and return it along with its type.
parseScalarLiteral(Node * n)173 ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
174 auto token = L.cur();
175 std::string str;
176 std::pair<TypePtr, std::optional<c10::AliasInfo>> type_alias;
177 ParsedLiteral r;
178 switch (token.kind) {
179 case TK_STRINGLITERAL:
180 r.k = AttributeKind::s;
181 r.s = parseStringLiteral(token.range, token.text());
182 L.next();
183 return r;
184 case '-':
185 str = "-";
186 L.next();
187 if (L.cur().kind != TK_NUMBER) {
188 throw(
189 ErrorReport(token.range)
190 << "Expected a number after '-' but got:" << token.text());
191 }
192 [[fallthrough]];
193 case TK_NUMBER:
194 str += L.cur().text();
195 if (str.find('j') != std::string::npos) {
196 r.k = AttributeKind::c;
197 double imag = 0.0f;
198 try {
199 imag = std::stod(str.substr(0, str.size() - 1));
200 } catch (const std::invalid_argument& e) {
201 throw(
202 ErrorReport(token.range)
203 << "Number cannot be converted to double");
204 } catch (const std::out_of_range& e) {
205 throw(
206 ErrorReport(token.range)
207 << "Number is too long to be represented in type double");
208 }
209 r.c = c10::complex<double>(0, imag);
210 } else if (
211 str.find('.') != std::string::npos ||
212 str.find('e') != std::string::npos) {
213 r.k = AttributeKind::f;
214 try {
215 r.f = std::stod(str);
216 } catch (const std::invalid_argument& e) {
217 throw(
218 ErrorReport(token.range)
219 << "Number cannot be converted to double");
220 } catch (const std::out_of_range& e) {
221 throw(
222 ErrorReport(token.range)
223 << "Number is too long to be represented in type double");
224 }
225 } else {
226 r.k = AttributeKind::i;
227 try {
228 r.i = std::stoll(str);
229 } catch (const std::invalid_argument& e) {
230 throw(
231 ErrorReport(token.range)
232 << "Number cannot be converted to integer");
233 } catch (const std::out_of_range& e) {
234 throw(ErrorReport(token.range) << "Number is too big");
235 }
236 }
237 L.next();
238 return r;
239 case TK_IDENT:
240 // Type literal
241 r.k = AttributeKind::ty;
242 type_alias = type_parser.parseType();
243 AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
244 r.ty = type_alias.first;
245 return r;
246 case '<': {
247 L.next();
248 auto text = L.expect(TK_IDENT);
249 if (text.text() != "Tensor") {
250 throw(
251 ErrorReport(token.range)
252 << "Could not parse literal" << token.text());
253 }
254 if (!parse_tensor_constants_) {
255 throw(
256 ErrorReport(token.range)
257 << "Tensor constant encountered but `parse_tensor_constants` set to false"
258 << token.text());
259 }
260 L.expect('>');
261 // these values will be set with randomly initialized data in
262 // a post processing pass;
263 deferred_tensor_value_initializations_.push_back(n);
264 r.k = AttributeKind::t;
265 return r;
266 }
267 case '{': {
268 L.next();
269 if (L.cur().kind == '-') {
270 L.next();
271 }
272 auto text = L.expect(TK_NUMBER);
273 if (!parse_tensor_constants_) {
274 throw(
275 ErrorReport(token.range)
276 << "Single-element tensor constant encountered but "
277 << "`parse_tensor_constants` is set to false " << token.text());
278 }
279 L.expect('}');
280 deferred_tensor_value_initializations_.push_back(n);
281 r.k = AttributeKind::t;
282 return r;
283 }
284 default:
285 throw(
286 ErrorReport(token.range)
287 << "Could not parse literal" << token.text());
288 }
289 }
290
bypassTypeAnnotationList()291 void IRParser::bypassTypeAnnotationList() {
292 int depth = 0;
293 bool bypassed_list = false;
294 while (depth != 0 || !bypassed_list) {
295 if (L.cur().kind == '[') {
296 bypassed_list = true;
297 depth++;
298 } else if (L.cur().kind == ']') {
299 depth--;
300 }
301 L.next();
302 }
303 }
304
305 /** \brief Parse attribute and add it to the node N.
306 *
307 * The function determines the attribute type (string, int, float, complex, list
308 * of strings, list of ints, list of floats, list of complex, and a list of
309 * tensors (currently only for empty lists)). An attribute looks like the
310 * following: AttrName=AttrValue Where AttrValue can be a list or a scalar
311 * literal, e.g.: size = 27 name = "Bob" coefs = [1.2, 3.4, 0.6]
312 */
parseAttr(Node * n)313 void IRParser::parseAttr(Node* n) {
314 std::string attrname = L.expect(TK_IDENT).text();
315 L.expect('=');
316 if (L.cur().kind == '[') {
317 // list
318 AttributeKind k = AttributeKind::ts;
319 c10::List<int64_t> is;
320 c10::List<std::string> ss;
321 c10::List<double> fs;
322 c10::List<c10::complex<double>> cs;
323 std::vector<TypePtr> tys;
324 int elem_num = 0;
325 parseList('[', ',', ']', [&] {
326 ParsedLiteral r = parseScalarLiteral(n);
327 switch (r.k) {
328 case AttributeKind::s:
329 ss.push_back(r.s);
330 AT_ASSERT(!elem_num++ || k == AttributeKind::ss);
331 k = AttributeKind::ss;
332 break;
333 case AttributeKind::i:
334 is.push_back(r.i);
335 AT_ASSERT(!elem_num++ || k == AttributeKind::is);
336 k = AttributeKind::is;
337 break;
338 case AttributeKind::f:
339 fs.push_back(r.f);
340 AT_ASSERT(!elem_num++ || k == AttributeKind::fs);
341 k = AttributeKind::fs;
342 break;
343 case AttributeKind::c:
344 cs.push_back(r.c);
345 AT_ASSERT(!elem_num++ || k == AttributeKind::cs);
346 k = AttributeKind::cs;
347 break;
348 case AttributeKind::ty:
349 tys.push_back(r.ty);
350 AT_ASSERT(!elem_num++ || k == AttributeKind::tys);
351 k = AttributeKind::tys;
352 break;
353 default:
354 throw(ErrorReport(L.cur().range) << "Unexpected attr type");
355 }
356 });
357 switch (k) {
358 case AttributeKind::ts:
359 n->ival_(Symbol::attr(attrname), IValue());
360 break;
361 case AttributeKind::ss:
362 n->ival_(Symbol::attr(attrname), IValue(ss));
363 break;
364 case AttributeKind::fs:
365 n->ival_(Symbol::attr(attrname), IValue(fs));
366 break;
367 case AttributeKind::cs:
368 n->ival_(Symbol::attr(attrname), IValue(cs));
369 break;
370 case AttributeKind::is:
371 n->ival_(Symbol::attr(attrname), IValue(is));
372 break;
373 case AttributeKind::tys:
374 n->tys_(Symbol::attr(attrname), tys);
375 break;
376 default:
377 throw(ErrorReport(L.cur().range) << "Unexpected attr type");
378 }
379 } else if (L.cur().text() == "annotate") {
380 L.next();
381 L.expect('(');
382 auto type = L.cur().text();
383 if (type != "List" && type != "Dict") {
384 throw(
385 ErrorReport(L.cur().range)
386 << "Unexpected annotation (only List and Dict can be parsed)");
387 }
388 L.next();
389 // ignore the annotations on the IValue constants, and instead recover
390 // type from the Node output
391 // Note: we could also use script_type_parser
392 bypassTypeAnnotationList();
393 L.expect(',');
394 // expect an empty definition (note - this isn't always true)
395 if (type == "Dict") {
396 L.expect('{');
397 L.expect('}');
398 } else if (type == "List") {
399 L.expect('[');
400 L.expect(']');
401 }
402 L.expect(')');
403 deferred_empty_container_initializations_.push_back(n);
404 } else {
405 // scalar
406 ParsedLiteral r = parseScalarLiteral(n);
407 switch (r.k) {
408 case AttributeKind::s:
409 n->s_(Symbol::attr(attrname), r.s);
410 break;
411 case AttributeKind::i:
412 n->i_(Symbol::attr(attrname), r.i);
413 break;
414 case AttributeKind::f:
415 n->f_(Symbol::attr(attrname), r.f);
416 break;
417 case AttributeKind::c:
418 n->c_(Symbol::attr(attrname), r.c);
419 break;
420 case AttributeKind::ty:
421 n->ty_(Symbol::attr(attrname), r.ty);
422 break;
423 case AttributeKind::t:
424 // initialized with random data later
425 break;
426 default:
427 throw(ErrorReport(L.cur().range) << "Unexpected attr type");
428 }
429 return;
430 }
431 }
432
parseAttrs(Node * n)433 void IRParser::parseAttrs(Node* n) {
434 parseList('[', ',', ']', [&] { parseAttr(n); });
435 }
436
parseOperatorInputs(Node * n)437 void IRParser::parseOperatorInputs(Node* n) {
438 if (L.cur().kind == '[') {
439 parseAttrs(n);
440 }
441 parseList('(', ',', ')', [&] {
442 std::string var_name = parseVar();
443 n->addInput(findValueInVMap(var_name));
444 });
445 }
446
parseBlocks(Node * parentNode)447 void IRParser::parseBlocks(Node* parentNode) {
448 L.expect(TK_INDENT);
449 while (L.cur().kind != TK_DEDENT) {
450 parseBlock(parentNode);
451 }
452 L.expect(TK_DEDENT);
453 }
454
parseBlockInputs(Block * b)455 void IRParser::parseBlockInputs(Block* b) {
456 parseList('(', ',', ')', [&] {
457 VarWithType v = parseVarWithType();
458 // If the name isn't valid, don't use it
459 std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
460 vmap[v.name] = b->addInput(uniq_name);
461 vmap[v.name]->setType(v.type);
462 });
463 }
464
parseBlockOutputs(Block * b)465 void IRParser::parseBlockOutputs(Block* b) {
466 L.expect(TK_ARROW);
467 parseList('(', ',', ')', [&] {
468 std::string var_name = parseVar();
469 b->registerOutput(findValueInVMap(var_name));
470 });
471 L.expect(TK_NEWLINE);
472 L.expect(TK_DEDENT);
473 }
474
475 /** \brief Parse a block.
476 *
477 * It should look like the following:
478 * blockName(input1, input2, input3, ...):
479 * op1
480 * op2
481 * ...
482 * opN
483 * -> (output1, output2, output3, ...)
484 */
parseBlock(Node * parentNode)485 void IRParser::parseBlock(Node* parentNode) {
486 Block* b = parentNode->addBlock();
487 L.expect(TK_IDENT).text(); // Block name is not used anywhere.
488 parseBlockInputs(b);
489 L.expect(':');
490 parseOperatorsList(b);
491 parseBlockOutputs(b);
492 }
493
494 /** \brief Parse a list of statements.
495 *
496 * It is expected to be delimited by TK_NEWLINE and end with TK_RETURN or
497 * TK_ARROW.
498 */
parseOperatorsList(Block * b)499 void IRParser::parseOperatorsList(Block* b) {
500 L.expect(TK_INDENT);
501 while (L.cur().kind != TK_ARROW && L.cur().kind != TK_RETURN) {
502 parseOperator(b);
503 }
504 }
505
parseOperatorName()506 std::string IRParser::parseOperatorName() {
507 std::string name = L.expect(TK_IDENT).text();
508 L.expect(':');
509 L.expect(':');
510 name += "::" + L.expect(TK_IDENT).text();
511 return name;
512 }
513
514 /** \brief Parse a statement.
515 *
516 * It should look like the following:
517 * <outputs> = NodeName[<attributes>](<inputs>)
518 * <blocks>
519 * Outputs, blocks and attributes are optional.
520 */
parseOperator(Block * b)521 void IRParser::parseOperator(Block* b) {
522 // Parse lefthand side.
523 std::vector<VarWithType> outs;
524 parseOperatorOutputs(&outs);
525
526 // Parse the name and create the corresponding node in the graph.
527 auto source_range = L.cur().range;
528 std::string name = parseOperatorName();
529 Node* n = g->create(Symbol::fromQualString(name), {}, outs.size())
530 ->setSourceRange(source_range);
531
532 // Parse attributes and inputs.
533 parseOperatorInputs(n);
534
535 const FunctionSchema* schema = n->maybeSchema();
536
537 // Register outputs.
538 unsigned idx = 0;
539 for (const VarWithType& v : outs) {
540 vmap[v.name] = n->outputs()[idx];
541 if (schema && !schema->is_varret()) {
542 TORCH_CHECK(
543 schema->returns().size() > idx,
544 "Operator parsing error: out of bounds access at ",
545 idx,
546 " to schema->returns() which size is ",
547 schema->returns().size(),
548 " in size");
549 auto schema_return_type = schema->returns().at(idx).type();
550 if (!v.type) {
551 vmap[v.name]->setType(schema_return_type);
552 } else {
553 // Don't currently support checking against type variables
554 // TODO: support?
555 if (!schema_return_type->hasFreeVariables() &&
556 !v.type->isSubtypeOf(*schema_return_type)) {
557 throw(
558 ErrorReport(source_range)
559 << "Annotated type " << v.type->repr_str()
560 << " does not match schema type "
561 << schema_return_type->repr_str() << " for operator " << *schema);
562 }
563 vmap[v.name]->setType(v.type);
564 }
565 } else {
566 vmap[v.name]->setType(v.type ? v.type : TensorType::get());
567 }
568 idx++;
569 }
570
571 // Insert the new node into block B.
572 b->appendNode(n);
573
574 // If the statement has nested blocks, parse them:
575 if (L.cur().kind == TK_INDENT) {
576 parseBlocks(n);
577 }
578 L.nextIf(TK_NEWLINE);
579 }
580
parseGraphInputs()581 void IRParser::parseGraphInputs() {
582 parseList('(', ',', ')', [&] {
583 VarWithType v = parseVarWithType();
584 // If the name isn't valid, don't use it
585 std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
586 vmap[v.name] = g->addInput(uniq_name);
587 vmap[v.name]->setType(v.type);
588 });
589 }
590
591 /** \brief Parse return statement.
592 *
593 * It should look like the following:
594 * return (x : TypeX, y : TypeY, z, ...)
595 */
parseReturnOperator()596 void IRParser::parseReturnOperator() {
597 L.expect(TK_RETURN);
598
599 // Parse output names and types
600 parseList('(', ',', ')', [&] {
601 std::string var_name = parseVar();
602 g->registerOutput(findValueInVMap(var_name));
603 });
604
605 // Consume ending tokens
606 if (L.cur().kind != TK_EOF) {
607 L.expect(TK_NEWLINE);
608 L.expect(TK_DEDENT);
609 }
610 }
611
612 /** \brief Parse entire graph.
613 *
614 * It should look like the following:
615 * graphName (input1, input2, ... inputN):
616 * op1
617 * op2
618 * ...
619 * opN
620 * return (output1, output2, ... outputN)
621 */
parse()622 void IRParser::parse() {
623 // Parse graph definition, it should look like the following:
624 // graphName (input1, input2, ... inputN):
625 L.expect(TK_IDENT);
626 parseGraphInputs();
627 L.expect(':');
628
629 // After the definition we should have a list of statements, parse it:
630 parseOperatorsList(g->block());
631
632 // The last statement should be return, which specifies graph outputs
633 parseReturnOperator();
634
635 for (Node* n : deferred_tensor_value_initializations_) {
636 auto type = n->output()->type()->expect<TensorType>();
637 auto tt = n->output()->type()->cast<TensorType>();
638 TORCH_INTERNAL_ASSERT(tt, "expected tensor output ", *n);
639 auto sizes = tt->sizes().concrete_sizes();
640 TORCH_INTERNAL_ASSERT(sizes);
641 auto strides = tt->strides().concrete_sizes();
642 TORCH_INTERNAL_ASSERT(strides);
643 auto device = tt->device();
644 TORCH_INTERNAL_ASSERT(device);
645 auto dtype = tt->scalarType();
646 TORCH_INTERNAL_ASSERT(dtype);
647 auto options = at::TensorOptions(*device).dtype(dtype);
648 auto t = n->t_(attr::value, at::empty_strided(*sizes, *strides, options));
649 (void)t;
650 }
651
652 for (Node* n : deferred_empty_container_initializations_) {
653 auto type = n->output()->type();
654 IValue val;
655 if (type->kind() == TypeKind::ListType) {
656 val = c10::impl::GenericList(type->containedType(0));
657 } else if (type->kind() == TypeKind::DictType) {
658 val = c10::impl::GenericDict(
659 type->containedType(0), type->containedType(1));
660 }
661 n->ival_(attr::value, val);
662 }
663 }
664
parseList(int begin,int sep,int end,const std::function<void ()> & callback)665 void IRParser::parseList(
666 int begin,
667 int sep,
668 int end,
669 const std::function<void()>& callback) {
670 if (begin != TK_NOTHING) {
671 L.expect(begin);
672 }
673 if (L.cur().kind != end) {
674 do {
675 callback();
676 } while (L.nextIf(sep));
677 }
678 if (end != TK_NOTHING) {
679 L.expect(end);
680 }
681 }
682
findValueInVMap(const std::string & name)683 Value* IRParser::findValueInVMap(const std::string& name) {
684 if (!vmap.count(name)) {
685 throw(
686 ErrorReport(L.cur().range)
687 << "Cannot find a variable with name '" << name << "'");
688 }
689 return vmap.at(name);
690 }
691
692 } // namespace torch::jit
693